perf: make embedder loading faster by 6x, and embed queries through the server
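
The loading speedup in this diff comes from passing torch_dtype=torch.float16, low_cpu_mem_usage=True, and local_files_only=True to AutoModel.from_pretrained (with a chain of fallbacks), and query embedding now reuses the running ZMQ embedding server instead of loading the model a second time in the API process. As a rough illustration of that client protocol — a msgpack-packed list of strings in, a msgpack-packed list of vectors out — here is a minimal sketch; the port 5557 mirrors the default used in compute_embeddings_sentence_transformers, and the helper name embed_queries is a hypothetical name for illustration, not part of this commit:

# Hypothetical client sketch: ask an already-running embedding server to embed query texts.
# Assumes a server is listening on tcp://localhost:5557 and speaks the
# msgpack list-of-strings -> list-of-vectors protocol shown in this diff.
import msgpack
import numpy as np
import zmq

def embed_queries(texts, port=5557):
    context = zmq.Context()
    socket = context.socket(zmq.REQ)
    socket.connect(f"tcp://localhost:{port}")
    try:
        # A plain list of strings (not starting with "__") is treated as a
        # direct text embedding request by the server loop.
        socket.send(msgpack.packb(list(texts)))
        vectors = msgpack.unpackb(socket.recv())
        return np.asarray(vectors, dtype=np.float32)
    finally:
        socket.close()
        context.term()

# Example usage: embeddings = embed_queries(["what is leann?"])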
@@ -16,8 +16,17 @@ import zmq
 import numpy as np
 import msgpack
 from pathlib import Path
+import logging
 
 RED = "\033[91m"
 
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 RESET = "\033[0m"
 
 # --- New Passage Loader from HNSW backend ---
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
     Create and run the embedding server in the current thread.
     This function is designed to be called from a separate thread.
     """
-    print(f"INFO: Initializing embedding server thread on port {zmq_port}")
+    logger.info(f"Initializing embedding server thread on port {zmq_port}")
 
     try:
         # Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
         if embedding_mode == "mlx":
             from leann.api import compute_embeddings_mlx
             import torch
-            print("INFO: Using MLX for embeddings")
+            logger.info("Using MLX for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
         elif embedding_mode == "openai":
             from leann.api import compute_embeddings_openai
             import torch
-            print("INFO: Using OpenAI API for embeddings")
+            logger.info("Using OpenAI API for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(
 
         if cuda_available:
             device = torch.device("cuda")
-            print("INFO: Using CUDA device")
+            logger.info("Using CUDA device")
         elif mps_available:
             device = torch.device("mps")
-            print("INFO: Using MPS device (Apple Silicon)")
+            logger.info("Using MPS device (Apple Silicon)")
         else:
             device = torch.device("cpu")
-            print("INFO: Using CPU device")
+            logger.info("Using CPU device")
 
         # Load the model
-        print(f"INFO: Loading model {model_name}")
+        logger.info(f"Loading model {model_name}")
         model = AutoModel.from_pretrained(model_name).to(device).eval()
 
         # Optimize the model
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
         try:
             model = model.half()
             model = torch.compile(model)
-            print(f"INFO: Using FP16 precision with model: {model_name}")
+            logger.info(f"Using FP16 precision with model: {model_name}")
         except Exception as e:
             print(f"WARNING: Model optimization failed: {e}")
         else:
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
             print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
             passages = SimplePassageLoader()
 
-        print(f"INFO: Loaded {len(passages)} passages.")
+        logger.info(f"Loaded {len(passages)} passages.")
 
         def client_warmup(zmq_port):
             """Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
         def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
             """Process a batch of texts."""
             batch_size = len(texts_batch)
-            print(f"INFO: Processing batch of size {batch_size}")
+            logger.info(f"Processing batch of size {batch_size}")
 
             tokenize_timer = DeviceTimer("tokenization (batch)", device)
             to_device_timer = DeviceTimer("transfer to device (batch)", device)
@@ -18,10 +18,19 @@ import json
 from pathlib import Path
 from typing import Dict, Any, Optional, Union
 import sys
+import logging
 
 RED = "\033[91m"
 RESET = "\033[0m"
 
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 
 
 def is_similarity_metric():
     """
@@ -36,6 +45,23 @@ import torch
 from torch import Tensor
 import torch.nn.functional as F
 
+# Timing utilities
+@contextmanager
+def timer(name: str, sync_cuda: bool = True):
+    """Context manager for timing operations with optional CUDA sync"""
+    start_time = time.time()
+    if sync_cuda and torch.cuda.is_available():
+        torch.cuda.synchronize()
+    try:
+        yield
+    finally:
+        if sync_cuda and torch.cuda.is_available():
+            torch.cuda.synchronize()
+        elif sync_cuda and torch.backends.mps.is_available():
+            torch.mps.synchronize()
+        elapsed = time.time() - start_time
+        logger.info(f"⏱️ {name}: {elapsed:.4f}s")
 
 
 def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
                 if passage_data and passage_data.get("text"):
                     return {"text": passage_data["text"]}
                 else:
-                    print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
+                    logger.debug(f"Empty text for ID {int_id} -> {string_id}")
                     return {"text": ""}
             else:
-                print(f"DEBUG: ID {int_id} not found in label_map")
+                logger.debug(f"ID {int_id} not found in label_map")
                 return {"text": ""}
         except Exception as e:
-            print(f"DEBUG: Exception getting passage {passage_id}: {e}")
+            logger.debug(f"Exception getting passage {passage_id}: {e}")
             return {"text": ""}
 
     def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
         tokenizer = None  # MLX handles tokenization separately
     else:  # sentence-transformers
         print(f"Loading tokenizer for {model_name}...")
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-        print(f"Tokenizer loaded successfully!")
+        # Optimized tokenizer loading: try local first, then fallback
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True,  # Use fast tokenizer (better runtime perf)
+                local_files_only=True  # Avoid network delays
+            )
+            print(f"Tokenizer loaded successfully! (local + fast)")
+        except Exception as e:
+            print(f"Local tokenizer failed ({e}), trying network download...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True  # Use fast tokenizer
+            )
+            print(f"Tokenizer loaded successfully! (network)")
 
     # Device setup
     mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
         print("OpenAI API mode - no local model loading required")
         model = None
     else:
-        # Use standard transformers for sentence-transformers models
-        model = AutoModel.from_pretrained(model_name).to(device).eval()
-        print(f"Model {model_name} loaded successfully!")
+        # Use optimized transformers loading for sentence-transformers models
+        print(f"Loading model with optimizations...")
+        try:
+            # Ultra-fast loading: preload config + fast_init
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(model_name, local_files_only=True)
+            model = AutoModel.from_pretrained(
+                model_name,
+                config=config,
+                torch_dtype=torch.float16,  # Half precision for speed
+                low_cpu_mem_usage=True,  # Reduce memory peaks
+                local_files_only=True,  # Avoid network delays
+                _fast_init=True  # Skip weight init checks
+            ).to(device).eval()
+            print(f"Model {model_name} loaded successfully! (ultra-fast)")
+        except Exception as e:
+            print(f"Ultra-fast loading failed ({e}), trying optimized...")
+            try:
+                # Fallback: regular optimized loading
+                model = AutoModel.from_pretrained(
+                    model_name,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                    local_files_only=True
+                ).to(device).eval()
+                print(f"Model {model_name} loaded successfully! (optimized)")
+            except Exception as e2:
+                print(f"Optimized loading failed ({e2}), trying network...")
+                try:
+                    # Fallback: optimized network loading
+                    model = AutoModel.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.float16,
+                        low_cpu_mem_usage=True
+                    ).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (network + optimized)")
+                except Exception as e3:
+                    print(f"All optimized methods failed ({e3}), using standard...")
+                    # Final fallback: standard loading
+                    model = AutoModel.from_pretrained(model_name).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (standard)")
 
     # Check port availability
     import socket
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
         if embedding_mode == "mlx":
             return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
         elif embedding_mode == "openai":
-            from leann.api import compute_embeddings_openai
-            return compute_embeddings_openai(texts_batch, model_name)
+            with timer("OpenAI API call", sync_cuda=False):
+                from leann.api import compute_embeddings_openai
+                return compute_embeddings_openai(texts_batch, model_name)
 
         _is_e5_model = "e5" in model_name.lower()
         _is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
         enc = {k: v.to(device) for k, v in encoded_batch.items()}
 
         with torch.no_grad():
-            with embed_timer.timing():
-                out = model(enc["input_ids"], enc["attention_mask"])
+            with timer("Model forward pass"):
+                with embed_timer.timing():
+                    out = model(enc["input_ids"], enc["attention_mask"])
 
-            with pool_timer.timing():
-                if _is_bge_model:
-                    pooled_embeddings = out.last_hidden_state[:, 0]
-                elif not hasattr(out, "last_hidden_state"):
-                    if isinstance(out, torch.Tensor) and len(out.shape) == 2:
-                        pooled_embeddings = out
-                    else:
-                        print(
-                            f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
-                        )
-                        hidden_dim = getattr(
-                            model.config, "hidden_size", 384 if _is_e5_model else 768
-                        )
-                        pooled_embeddings = torch.zeros(
-                            (batch_size, hidden_dim),
-                            device=device,
-                            dtype=enc["input_ids"].dtype
-                            if hasattr(enc["input_ids"], "dtype")
-                            else torch.float32,
-                        )
-                elif _is_e5_model:
-                    pooled_embeddings = e5_average_pool(
-                        out.last_hidden_state, enc["attention_mask"]
-                    )
-                else:
-                    hidden_states = out.last_hidden_state
-                    mask_expanded = (
-                        enc["attention_mask"]
-                        .unsqueeze(-1)
-                        .expand(hidden_states.size())
-                        .float()
-                    )
-                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
-                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
-                    pooled_embeddings = sum_embeddings / sum_mask
+            with timer("Pooling"):
+                with pool_timer.timing():
+                    if _is_bge_model:
+                        pooled_embeddings = out.last_hidden_state[:, 0]
+                    elif not hasattr(out, "last_hidden_state"):
+                        if isinstance(out, torch.Tensor) and len(out.shape) == 2:
+                            pooled_embeddings = out
+                        else:
+                            print(
+                                f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
+                            )
+                            hidden_dim = getattr(
+                                model.config, "hidden_size", 384 if _is_e5_model else 768
+                            )
+                            pooled_embeddings = torch.zeros(
+                                (batch_size, hidden_dim),
+                                device=device,
+                                dtype=enc["input_ids"].dtype
+                                if hasattr(enc["input_ids"], "dtype")
+                                else torch.float32,
+                            )
+                    elif _is_e5_model:
+                        pooled_embeddings = e5_average_pool(
+                            out.last_hidden_state, enc["attention_mask"]
+                        )
+                    else:
+                        hidden_states = out.last_hidden_state
+                        mask_expanded = (
+                            enc["attention_mask"]
+                            .unsqueeze(-1)
+                            .expand(hidden_states.size())
+                            .float()
+                        )
+                        sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
+                        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
+                        pooled_embeddings = sum_embeddings / sum_mask
 
         final_embeddings = pooled_embeddings
         if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(
 
     def zmq_server_thread():
         """ZMQ server thread"""
-        nonlocal passages, model, tokenizer, model_name
+        nonlocal passages, model, tokenizer, model_name, embedding_mode
         context = zmq.Context()
         socket = context.socket(zmq.REP)
         socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
             try:
                 request_payload = msgpack.unpackb(message_bytes)
                 if isinstance(request_payload, list):
-                    print(f"DEBUG: request_payload length: {len(request_payload)}")
+                    logger.debug(f"request_payload length: {len(request_payload)}")
                     for i, item in enumerate(request_payload):
                         print(
                             f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
                         )
 
-                # Handle control messages for meta path and model management
+                # Handle control messages for meta path and model management FIRST
                 if isinstance(request_payload, list) and len(request_payload) >= 1:
                     if request_payload[0] == "__QUERY_META_PATH__":
                         # Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
                     )
 
                     # Clean up old model to free memory
-                    print("INFO: Releasing old model from memory...")
+                    logger.info("Releasing old model from memory...")
                     old_model = model
                     old_tokenizer = tokenizer
 
-                    # Load new tokenizer first
+                    # Load new tokenizer first (optimized)
                     print(f"Loading new tokenizer for {new_model_name}...")
-                    tokenizer = AutoTokenizer.from_pretrained(
-                        new_model_name, use_fast=True
-                    )
+                    try:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            new_model_name,
+                            use_fast=True,
+                            local_files_only=True
+                        )
+                        print(f"New tokenizer loaded! (local + fast)")
+                    except:
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            new_model_name,
+                            use_fast=True
+                        )
+                        print(f"New tokenizer loaded! (network + fast)")
 
-                    # Load new model
+                    # Load new model (optimized)
                     print(f"Loading new model {new_model_name}...")
-                    model = AutoModel.from_pretrained(new_model_name)
+                    try:
+                        # Ultra-fast model switching
+                        from transformers import AutoConfig
+                        config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
+                        model = AutoModel.from_pretrained(
+                            new_model_name,
+                            config=config,
+                            torch_dtype=torch.float16,
+                            low_cpu_mem_usage=True,
+                            local_files_only=True,
+                            _fast_init=True
+                        )
+                        print(f"New model loaded! (ultra-fast)")
+                    except:
+                        try:
+                            model = AutoModel.from_pretrained(
+                                new_model_name,
+                                torch_dtype=torch.float16,
+                                low_cpu_mem_usage=True,
+                                local_files_only=True
+                            )
+                            print(f"New model loaded! (optimized)")
+                        except:
+                            try:
+                                model = AutoModel.from_pretrained(
+                                    new_model_name,
+                                    torch_dtype=torch.float16,
+                                    low_cpu_mem_usage=True
+                                )
+                                print(f"New model loaded! (network + optimized)")
+                            except:
+                                model = AutoModel.from_pretrained(new_model_name)
+                                print(f"New model loaded! (standard)")
                     model.to(device)
                     model.eval()
 
@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
                     # Clear GPU cache if available
                     if device.type == "cuda":
                         torch.cuda.empty_cache()
-                        print("INFO: Cleared CUDA cache")
+                        logger.info("Cleared CUDA cache")
                     elif device.type == "mps":
                         torch.mps.empty_cache()
-                        print("INFO: Cleared MPS cache")
+                        logger.info("Cleared MPS cache")
 
                     # Update model name
                     model_name = new_model_name
 
+                    # Re-detect embedding mode based on new model name
+                    if model_name.startswith("text-embedding-"):
+                        embedding_mode = "openai"
+                        logger.info(f"Auto-detected embedding mode: openai for {model_name}")
+                    else:
+                        embedding_mode = "sentence-transformers"
+                        logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")
 
                     # Force garbage collection
                     import gc
 
                     gc.collect()
-                    print("INFO: Memory cleanup completed")
+                    logger.info("Memory cleanup completed")
 
                     response = ["SUCCESS"]
                     print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
                     socket.send(msgpack.packb(response))
                     continue
 
+                # Handle direct text embedding request (for OpenAI and sentence-transformers)
+                if isinstance(request_payload, list) and len(request_payload) > 0:
+                    # Check if this is a direct text request (list of strings) and NOT a control message
+                    if (all(isinstance(item, str) for item in request_payload) and
+                            not request_payload[0].startswith("__")):
+                        logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")
+
+                        try:
+                            if embedding_mode == "openai":
+                                from leann.api import compute_embeddings_openai
+                                embeddings = compute_embeddings_openai(request_payload, model_name)
+                            else:
+                                # sentence-transformers mode - compute directly
+                                with timer(f"Direct text embedding ({len(request_payload)} texts)"):
+                                    embeddings = process_batch(request_payload, [], [])
+
+                            response = embeddings.tolist()
+                            socket.send(msgpack.packb(response))
+                            e2e_end = time.time()
+                            logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
+                            continue
+                        except Exception as e:
+                            logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
+                            socket.send(msgpack.packb([]))
+                            continue
+
                 # Handle distance calculation requests
                 if (
                     isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
                     node_ids = request_payload[0]
                     query_vector = np.array(request_payload[1], dtype=np.float32)
 
-                    print("DEBUG: Distance calculation request received")
+                    logger.debug("Distance calculation request received")
                     print(f" Node IDs: {node_ids}")
                     print(f" Query vector dim: {len(query_vector)}")
                     print(f" Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
                     missing_ids = []
                     with lookup_timer.timing():
                         for nid in node_ids:
-                            print(f"DEBUG: Looking up passage ID {nid}")
+                            logger.debug(f"Looking up passage ID {nid}")
                             try:
                                 txtinfo = passages[nid]
                                 if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(
-                        f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
-                    )
+                    logger.info(
+                        f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
+                    )
                     continue
 
-                # Handle direct text embedding request (for OpenAI mode)
-                if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
-                    # Check if this is a direct text request (list of strings)
-                    if all(isinstance(item, str) for item in request_payload):
-                        print(f"Processing direct text embedding request for {len(request_payload)} texts")
-
-                        try:
-                            from leann.api import compute_embeddings_openai
-                            embeddings = compute_embeddings_openai(request_payload, model_name)
-                            response = embeddings.tolist()
-                            socket.send(msgpack.packb(response))
-                            e2e_end = time.time()
-                            print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
-                            continue
-                        except Exception as e:
-                            print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
-                            socket.send(msgpack.packb([]))
-                            continue
 
                 # Standard embedding request (passage ID lookup)
                 if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
+                    logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
 
             except zmq.Again:
-                print("ZMQ socket timeout, continuing to listen")
+                logger.debug("ZMQ socket timeout, continuing to listen")
                 continue
             except Exception as e:
                 print(f"Error in ZMQ server loop: {e}")
@@ -51,7 +51,63 @@ def compute_embeddings(
 
 
 def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers library."""
+    """Computes embeddings using sentence-transformers via embedding server."""
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+    )
+
+    # Use embedding server for sentence-transformers too
+    # This avoids loading the model twice (once in API, once in server)
+    try:
+        # Import ZMQ client functionality and server manager
+        import zmq
+        import msgpack
+        import numpy as np
+        from .embedding_server_manager import EmbeddingServerManager
+
+        # Ensure embedding server is running
+        port = 5557
+        server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server")
+
+        server_started = server_manager.start_server(
+            port=port,
+            model_name=model_name,
+            embedding_mode="sentence-transformers",
+            enable_warmup=False,
+        )
+
+        if not server_started:
+            raise RuntimeError(f"Failed to start embedding server on port {port}")
+
+        # Connect to embedding server
+        context = zmq.Context()
+        socket = context.socket(zmq.REQ)
+        socket.connect(f"tcp://localhost:{port}")
+
+        # Send chunks to server for embedding computation
+        request = chunks
+        socket.send(msgpack.packb(request))
+
+        # Receive embeddings from server
+        response = socket.recv()
+        embeddings_list = msgpack.unpackb(response)
+
+        # Convert back to numpy array
+        embeddings = np.array(embeddings_list, dtype=np.float32)
+
+        socket.close()
+        context.term()
+
+        return embeddings
+
+    except Exception as e:
+        # Fallback to direct sentence-transformers if server connection fails
+        print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}")
+        return _compute_embeddings_sentence_transformers_direct(chunks, model_name)
+
+
+def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray:
+    """Direct sentence-transformers computation (fallback)."""
     try:
         from sentence_transformers import SentenceTransformer
     except ImportError as e:
@@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str)
 
     model = model.half()
     print(
-        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..."
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..."
     )
     # use accelerator GPU or Mac GPU
 
@@ -200,7 +200,27 @@ class EmbeddingServerManager:
 
                 # Check model compatibility
                 model_matches = _check_server_model(self.server_port, model_name)
-                if not model_matches:
+                if model_matches:
+                    print(
+                        f"✅ Existing server already using correct model: {model_name}"
+                    )
+
+                    # Still check meta path if provided
+                    passages_file = kwargs.get("passages_file")
+                    if passages_file and str(passages_file).endswith(
+                        ".meta.json"
+                    ):
+                        meta_matches = _check_server_meta_path(
+                            self.server_port, str(passages_file)
+                        )
+                        if not meta_matches:
+                            print(f"⚠️ Updating meta path to: {passages_file}")
+                            _update_server_meta_path(
+                                self.server_port, str(passages_file)
+                            )
+
+                    return True
+                else:
                     print(
                         f"⚠️ Existing server has different model. Attempting to update to: {model_name}"
                     )
@@ -230,11 +250,6 @@ class EmbeddingServerManager:
                     )
 
                     return True
-                else:
-                    print(
-                        f"✅ Existing server already using correct model: {model_name}"
-                    )
-                    return True
             else:
                 # Server process exists but port not responding - restart
                 print("⚠️ Server process exists but not responding. Restarting...")
@@ -254,7 +269,11 @@ class EmbeddingServerManager:
 
         # Check model compatibility first
         model_matches = _check_server_model(port, model_name)
-        if not model_matches:
+        if model_matches:
+            print(
+                f"✅ Existing server on port {port} is using correct model: {model_name}"
+            )
+        else:
             print(
                 f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
             )
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
                     f"❌ Failed to update server model to {model_name}. Consider using a different port."
                 )
             print(f"✅ Successfully updated server model to: {model_name}")
-        else:
-            print(
-                f"✅ Existing server on port {port} is using correct model: {model_name}"
-            )
 
         # Check meta path compatibility if provided
         if passages_file and str(passages_file).endswith(".meta.json"):