diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
index de58e3a..e3f719a 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
@@ -16,8 +16,17 @@ import zmq
 import numpy as np
 import msgpack
 from pathlib import Path
+import logging

 RED = "\033[91m"
+
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 RESET = "\033[0m"

 # --- New Passage Loader from HNSW backend ---
@@ -169,7 +178,7 @@ def create_embedding_server_thread(
     Create and run the embedding server in the current thread.
     This function is designed to be called from a separate thread.
     """
-    print(f"INFO: Initializing embedding server thread on port {zmq_port}")
+    logger.info(f"Initializing embedding server thread on port {zmq_port}")

     try:
         # Check whether the port is already in use
@@ -189,7 +198,7 @@ def create_embedding_server_thread(
         if embedding_mode == "mlx":
             from leann.api import compute_embeddings_mlx
             import torch
-            print("INFO: Using MLX for embeddings")
+            logger.info("Using MLX for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
@@ -197,7 +206,7 @@ def create_embedding_server_thread(
         elif embedding_mode == "openai":
             from leann.api import compute_embeddings_openai
             import torch
-            print("INFO: Using OpenAI API for embeddings")
+            logger.info("Using OpenAI API for embeddings")
             # Set device to CPU for compatibility with DeviceTimer class
             device = torch.device("cpu")
             cuda_available = False
@@ -213,16 +222,16 @@ def create_embedding_server_thread(

             if cuda_available:
                 device = torch.device("cuda")
-                print("INFO: Using CUDA device")
+                logger.info("Using CUDA device")
             elif mps_available:
                 device = torch.device("mps")
-                print("INFO: Using MPS device (Apple Silicon)")
+                logger.info("Using MPS device (Apple Silicon)")
             else:
                 device = torch.device("cpu")
-                print("INFO: Using CPU device")
+                logger.info("Using CPU device")

             # Load the model
-            print(f"INFO: Loading model {model_name}")
+            logger.info(f"Loading model {model_name}")
             model = AutoModel.from_pretrained(model_name).to(device).eval()

             # Optimize the model
@@ -230,7 +239,7 @@ def create_embedding_server_thread(
                 try:
                     model = model.half()
                     model = torch.compile(model)
-                    print(f"INFO: Using FP16 precision with model: {model_name}")
+                    logger.info(f"Using FP16 precision with model: {model_name}")
                 except Exception as e:
                     print(f"WARNING: Model optimization failed: {e}")
             else:
@@ -256,7 +265,7 @@ def create_embedding_server_thread(
             print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
             passages = SimplePassageLoader()
-        print(f"INFO: Loaded {len(passages)} passages.")
+        logger.info(f"Loaded {len(passages)} passages.")

     def client_warmup(zmq_port):
         """Perform client-side warmup for DiskANN server"""
@@ -365,7 +374,7 @@ def create_embedding_server_thread(
     def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
         """Process a batch of texts"""
         batch_size = len(texts_batch)
-        print(f"INFO: Processing batch of size {batch_size}")
+        logger.info(f"Processing batch of size {batch_size}")

         tokenize_timer = DeviceTimer("tokenization (batch)", device)
         to_device_timer = DeviceTimer("transfer to device (batch)", device)
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index f60085a..db302fd 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -18,10 +18,19 @@ import json
 from pathlib import Path
 from typing import Dict, Any, Optional, Union
 import sys
+import logging

 RED = "\033[91m"
 RESET = "\033[0m"

+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(
+    level=getattr(logging, LOG_LEVEL, logging.INFO),
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+

 def is_similarity_metric():
     """
@@ -36,6 +45,23 @@ import torch
 from torch import Tensor
 import torch.nn.functional as F

+# Timing utilities
+@contextmanager
+def timer(name: str, sync_cuda: bool = True):
+    """Context manager for timing operations with optional CUDA sync"""
+    start_time = time.time()
+    if sync_cuda and torch.cuda.is_available():
+        torch.cuda.synchronize()
+    try:
+        yield
+    finally:
+        if sync_cuda and torch.cuda.is_available():
+            torch.cuda.synchronize()
+        elif sync_cuda and torch.backends.mps.is_available():
+            torch.mps.synchronize()
+        elapsed = time.time() - start_time
+        logger.info(f"⏱️ {name}: {elapsed:.4f}s")
+

 def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
@@ -120,13 +146,13 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
                 if passage_data and passage_data.get("text"):
                     return {"text": passage_data["text"]}
                 else:
-                    print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
+                    logger.debug(f"Empty text for ID {int_id} -> {string_id}")
                     return {"text": ""}
             else:
-                print(f"DEBUG: ID {int_id} not found in label_map")
+                logger.debug(f"ID {int_id} not found in label_map")
                 return {"text": ""}
         except Exception as e:
-            print(f"DEBUG: Exception getting passage {passage_id}: {e}")
+            logger.debug(f"Exception getting passage {passage_id}: {e}")
            return {"text": ""}

     def __len__(self) -> int:
@@ -184,8 +210,21 @@ def create_hnsw_embedding_server(
         tokenizer = None  # MLX handles tokenization separately
     else:  # sentence-transformers
         print(f"Loading tokenizer for {model_name}...")
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-        print(f"Tokenizer loaded successfully!")
+        # Optimized tokenizer loading: try local first, then fallback
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True,  # Use fast tokenizer (better runtime perf)
+                local_files_only=True  # Avoid network delays
+            )
+            print(f"Tokenizer loaded successfully! (local + fast)")
+        except Exception as e:
+            print(f"Local tokenizer failed ({e}), trying network download...")
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                use_fast=True  # Use fast tokenizer
+            )
+            print(f"Tokenizer loaded successfully! (network)")

     # Device setup
     mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
@@ -217,9 +256,47 @@ def create_hnsw_embedding_server(
         print("OpenAI API mode - no local model loading required")
         model = None
     else:
-        # Use standard transformers for sentence-transformers models
-        model = AutoModel.from_pretrained(model_name).to(device).eval()
-        print(f"Model {model_name} loaded successfully!")
+        # Use optimized transformers loading for sentence-transformers models
+        print(f"Loading model with optimizations...")
+        try:
+            # Ultra-fast loading: preload config + fast_init
+            from transformers import AutoConfig
+            config = AutoConfig.from_pretrained(model_name, local_files_only=True)
+            model = AutoModel.from_pretrained(
+                model_name,
+                config=config,
+                torch_dtype=torch.float16,  # Half precision for speed
+                low_cpu_mem_usage=True,  # Reduce memory peaks
+                local_files_only=True,  # Avoid network delays
+                _fast_init=True  # Skip weight init checks
+            ).to(device).eval()
+            print(f"Model {model_name} loaded successfully! (ultra-fast)")
+        except Exception as e:
+            print(f"Ultra-fast loading failed ({e}), trying optimized...")
+            try:
+                # Fallback: regular optimized loading
+                model = AutoModel.from_pretrained(
+                    model_name,
+                    torch_dtype=torch.float16,
+                    low_cpu_mem_usage=True,
+                    local_files_only=True
+                ).to(device).eval()
+                print(f"Model {model_name} loaded successfully! (optimized)")
+            except Exception as e2:
+                print(f"Optimized loading failed ({e2}), trying network...")
+                try:
+                    # Fallback: optimized network loading
+                    model = AutoModel.from_pretrained(
+                        model_name,
+                        torch_dtype=torch.float16,
+                        low_cpu_mem_usage=True
+                    ).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (network + optimized)")
+                except Exception as e3:
+                    print(f"All optimized methods failed ({e3}), using standard...")
+                    # Final fallback: standard loading
+                    model = AutoModel.from_pretrained(model_name).to(device).eval()
+                    print(f"Model {model_name} loaded successfully! (standard)")

     # Check port availability
     import socket
@@ -370,8 +447,9 @@ def create_hnsw_embedding_server(
         if embedding_mode == "mlx":
             return _process_batch_mlx(texts_batch, ids_batch, missing_ids)
         elif embedding_mode == "openai":
-            from leann.api import compute_embeddings_openai
-            return compute_embeddings_openai(texts_batch, model_name)
+            with timer("OpenAI API call", sync_cuda=False):
+                from leann.api import compute_embeddings_openai
+                return compute_embeddings_openai(texts_batch, model_name)

         _is_e5_model = "e5" in model_name.lower()
         _is_bge_model = "bge" in model_name.lower()
@@ -417,44 +495,46 @@ def create_hnsw_embedding_server(
         enc = {k: v.to(device) for k, v in encoded_batch.items()}

         with torch.no_grad():
-            with embed_timer.timing():
-                out = model(enc["input_ids"], enc["attention_mask"])
+            with timer("Model forward pass"):
+                with embed_timer.timing():
+                    out = model(enc["input_ids"], enc["attention_mask"])

-            with pool_timer.timing():
-                if _is_bge_model:
-                    pooled_embeddings = out.last_hidden_state[:, 0]
-                elif not hasattr(out, "last_hidden_state"):
-                    if isinstance(out, torch.Tensor) and len(out.shape) == 2:
-                        pooled_embeddings = out
+            with timer("Pooling"):
+                with pool_timer.timing():
+                    if _is_bge_model:
+                        pooled_embeddings = out.last_hidden_state[:, 0]
+                    elif not hasattr(out, "last_hidden_state"):
+                        if isinstance(out, torch.Tensor) and len(out.shape) == 2:
+                            pooled_embeddings = out
+                        else:
+                            print(
+                                f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
+                            )
+                            hidden_dim = getattr(
+                                model.config, "hidden_size", 384 if _is_e5_model else 768
+                            )
+                            pooled_embeddings = torch.zeros(
+                                (batch_size, hidden_dim),
+                                device=device,
+                                dtype=enc["input_ids"].dtype
+                                if hasattr(enc["input_ids"], "dtype")
+                                else torch.float32,
+                            )
+                    elif _is_e5_model:
+                        pooled_embeddings = e5_average_pool(
+                            out.last_hidden_state, enc["attention_mask"]
+                        )
                     else:
-                        print(
-                            f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
+                        hidden_states = out.last_hidden_state
+                        mask_expanded = (
+                            enc["attention_mask"]
+                            .unsqueeze(-1)
+                            .expand(hidden_states.size())
+                            .float()
                         )
-                    hidden_dim = getattr(
-                        model.config, "hidden_size", 384 if _is_e5_model else 768
-                    )
-                    pooled_embeddings = torch.zeros(
-                        (batch_size, hidden_dim),
-                        device=device,
-                        dtype=enc["input_ids"].dtype
-                        if hasattr(enc["input_ids"], "dtype")
-                        else torch.float32,
-                    )
-                elif _is_e5_model:
-                    pooled_embeddings = e5_average_pool(
-                        out.last_hidden_state, enc["attention_mask"]
-                    )
-                else:
-                    hidden_states = out.last_hidden_state
-                    mask_expanded = (
-                        enc["attention_mask"]
-                        .unsqueeze(-1)
-                        .expand(hidden_states.size())
-                        .float()
-                    )
-                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
-                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
-                    pooled_embeddings = sum_embeddings / sum_mask
+                        sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
+                        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
+                        pooled_embeddings = sum_embeddings / sum_mask

         final_embeddings = pooled_embeddings
         if _is_e5_model or _is_bge_model:
@@ -536,7 +616,7 @@ def create_hnsw_embedding_server(

     def zmq_server_thread():
         """ZMQ server thread"""
-        nonlocal passages, model, tokenizer, model_name
+        nonlocal passages, model, tokenizer, model_name, embedding_mode
         context = zmq.Context()
         socket = context.socket(zmq.REP)
         socket.bind(f"tcp://*:{zmq_port}")
@@ -556,13 +636,13 @@ def create_hnsw_embedding_server(
             try:
                 request_payload = msgpack.unpackb(message_bytes)
                 if isinstance(request_payload, list):
-                    print(f"DEBUG: request_payload length: {len(request_payload)}")
+                    logger.debug(f"request_payload length: {len(request_payload)}")
                     for i, item in enumerate(request_payload):
                         print(
                             f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
                         )

-                # Handle control messages for meta path and model management
+                # Handle control messages for meta path and model management FIRST
                 if isinstance(request_payload, list) and len(request_payload) >= 1:
                     if request_payload[0] == "__QUERY_META_PATH__":
                         # Return the current meta path being used by the server
@@ -617,19 +697,61 @@ def create_hnsw_embedding_server(
                         )

                         # Clean up old model to free memory
-                        print("INFO: Releasing old model from memory...")
+                        logger.info("Releasing old model from memory...")
                         old_model = model
                         old_tokenizer = tokenizer

-                        # Load new tokenizer first
+                        # Load new tokenizer first (optimized)
                         print(f"Loading new tokenizer for {new_model_name}...")
-                        tokenizer = AutoTokenizer.from_pretrained(
-                            new_model_name, use_fast=True
-                        )
+                        try:
+                            tokenizer = AutoTokenizer.from_pretrained(
+                                new_model_name,
+                                use_fast=True,
+                                local_files_only=True
+                            )
+                            print(f"New tokenizer loaded! (local + fast)")
+                        except:
+                            tokenizer = AutoTokenizer.from_pretrained(
+                                new_model_name,
+                                use_fast=True
+                            )
+                            print(f"New tokenizer loaded! (network + fast)")

-                        # Load new model
+                        # Load new model (optimized)
                         print(f"Loading new model {new_model_name}...")
-                        model = AutoModel.from_pretrained(new_model_name)
+                        try:
+                            # Ultra-fast model switching
+                            from transformers import AutoConfig
+                            config = AutoConfig.from_pretrained(new_model_name, local_files_only=True)
+                            model = AutoModel.from_pretrained(
+                                new_model_name,
+                                config=config,
+                                torch_dtype=torch.float16,
+                                low_cpu_mem_usage=True,
+                                local_files_only=True,
+                                _fast_init=True
+                            )
+                            print(f"New model loaded! (ultra-fast)")
+                        except:
+                            try:
+                                model = AutoModel.from_pretrained(
+                                    new_model_name,
+                                    torch_dtype=torch.float16,
+                                    low_cpu_mem_usage=True,
+                                    local_files_only=True
+                                )
+                                print(f"New model loaded! (optimized)")
+                            except:
+                                try:
+                                    model = AutoModel.from_pretrained(
+                                        new_model_name,
+                                        torch_dtype=torch.float16,
+                                        low_cpu_mem_usage=True
+                                    )
+                                    print(f"New model loaded! (network + optimized)")
+                                except:
+                                    model = AutoModel.from_pretrained(new_model_name)
+                                    print(f"New model loaded! (standard)")
                         model.to(device)
                         model.eval()

@@ -640,19 +762,27 @@ def create_hnsw_embedding_server(
                         # Clear GPU cache if available
                         if device.type == "cuda":
                             torch.cuda.empty_cache()
-                            print("INFO: Cleared CUDA cache")
+                            logger.info("Cleared CUDA cache")
                         elif device.type == "mps":
                             torch.mps.empty_cache()
-                            print("INFO: Cleared MPS cache")
+                            logger.info("Cleared MPS cache")

                         # Update model name
                         model_name = new_model_name
+
+                        # Re-detect embedding mode based on new model name
+                        if model_name.startswith("text-embedding-"):
+                            embedding_mode = "openai"
+                            logger.info(f"Auto-detected embedding mode: openai for {model_name}")
+                        else:
+                            embedding_mode = "sentence-transformers"
+                            logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}")

                         # Force garbage collection
                         import gc
                         gc.collect()
-                        print("INFO: Memory cleanup completed")
+                        logger.info("Memory cleanup completed")

                         response = ["SUCCESS"]
                         print(
@@ -664,6 +794,32 @@ def create_hnsw_embedding_server(
                     socket.send(msgpack.packb(response))
                     continue

+                # Handle direct text embedding request (for OpenAI and sentence-transformers)
+                if isinstance(request_payload, list) and len(request_payload) > 0:
+                    # Check if this is a direct text request (list of strings) and NOT a control message
+                    if (all(isinstance(item, str) for item in request_payload) and
+                        not request_payload[0].startswith("__")):
+                        logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode")
+
+                        try:
+                            if embedding_mode == "openai":
+                                from leann.api import compute_embeddings_openai
+                                embeddings = compute_embeddings_openai(request_payload, model_name)
+                            else:
+                                # sentence-transformers mode - compute directly
+                                with timer(f"Direct text embedding ({len(request_payload)} texts)"):
+                                    embeddings = process_batch(request_payload, [], [])
+
+                            response = embeddings.tolist()
+                            socket.send(msgpack.packb(response))
+                            e2e_end = time.time()
+                            logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
+                            continue
+                        except Exception as e:
+                            logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}")
+                            socket.send(msgpack.packb([]))
+                            continue
+
                 # Handle distance calculation requests
                 if (
                     isinstance(request_payload, list)
@@ -674,7 +830,7 @@ def create_hnsw_embedding_server(
                     node_ids = request_payload[0]
                     query_vector = np.array(request_payload[1], dtype=np.float32)

-                    print("DEBUG: Distance calculation request received")
+                    logger.debug("Distance calculation request received")
                     print(f"  Node IDs: {node_ids}")
                     print(f"  Query vector dim: {len(query_vector)}")
                     print(f"  Passages loaded: {len(passages)}")
@@ -684,7 +840,7 @@ def create_hnsw_embedding_server(
                     missing_ids = []
                     with lookup_timer.timing():
                         for nid in node_ids:
-                            print(f"DEBUG: Looking up passage ID {nid}")
+                            logger.debug(f"Looking up passage ID {nid}")
                             try:
                                 txtinfo = passages[nid]
                                 if txtinfo is None:
@@ -804,29 +960,11 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(
-                        f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
+                    logger.info(
+                        f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
                     )
                     continue

-                # Handle direct text embedding request (for OpenAI mode)
-                if embedding_mode == "openai" and isinstance(request_payload, list) and len(request_payload) > 0:
-                    # Check if this is a direct text request (list of strings)
-                    if all(isinstance(item, str) for item in request_payload):
-                        print(f"Processing direct text embedding request for {len(request_payload)} texts")
-
-                        try:
-                            from leann.api import compute_embeddings_openai
-                            embeddings = compute_embeddings_openai(request_payload, model_name)
-                            response = embeddings.tolist()
-                            socket.send(msgpack.packb(response))
-                            e2e_end = time.time()
-                            print(f"Text embedding E2E time: {e2e_end - e2e_start:.6f} seconds")
-                            continue
-                        except Exception as e:
-                            print(f"ERROR: Failed to compute OpenAI embeddings: {e}")
-                            socket.send(msgpack.packb([]))
-                            continue

                 # Standard embedding request (passage ID lookup)
                 if (
@@ -945,10 +1083,10 @@ def create_hnsw_embedding_server(
                     elif device.type == "mps":
                         torch.mps.synchronize()
                     e2e_end = time.time()
-                    print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
+                    logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

             except zmq.Again:
-                print("ZMQ socket timeout, continuing to listen")
+                logger.debug("ZMQ socket timeout, continuing to listen")
                 continue
             except Exception as e:
                 print(f"Error in ZMQ server loop: {e}")
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index 2042ac8..ec30bad 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -51,7 +51,63 @@ def compute_embeddings(


 def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) -> np.ndarray:
-    """Computes embeddings using sentence-transformers library."""
+    """Computes embeddings using sentence-transformers via embedding server."""
+    print(
+        f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..."
+ ) + + # Use embedding server for sentence-transformers too + # This avoids loading the model twice (once in API, once in server) + try: + # Import ZMQ client functionality and server manager + import zmq + import msgpack + import numpy as np + from .embedding_server_manager import EmbeddingServerManager + + # Ensure embedding server is running + port = 5557 + server_manager = EmbeddingServerManager(backend_module_name="leann_backend_hnsw.hnsw_embedding_server") + + server_started = server_manager.start_server( + port=port, + model_name=model_name, + embedding_mode="sentence-transformers", + enable_warmup=False, + ) + + if not server_started: + raise RuntimeError(f"Failed to start embedding server on port {port}") + + # Connect to embedding server + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect(f"tcp://localhost:{port}") + + # Send chunks to server for embedding computation + request = chunks + socket.send(msgpack.packb(request)) + + # Receive embeddings from server + response = socket.recv() + embeddings_list = msgpack.unpackb(response) + + # Convert back to numpy array + embeddings = np.array(embeddings_list, dtype=np.float32) + + socket.close() + context.term() + + return embeddings + + except Exception as e: + # Fallback to direct sentence-transformers if server connection fails + print(f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}") + return _compute_embeddings_sentence_transformers_direct(chunks, model_name) + + +def _compute_embeddings_sentence_transformers_direct(chunks: List[str], model_name: str) -> np.ndarray: + """Direct sentence-transformers computation (fallback).""" try: from sentence_transformers import SentenceTransformer except ImportError as e: @@ -64,7 +120,7 @@ def compute_embeddings_sentence_transformers(chunks: List[str], model_name: str) model = model.half() print( - f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'..." + f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..." ) # use acclerater GPU or MAC GPU diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 2a5e302..303adac 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -200,7 +200,27 @@ class EmbeddingServerManager: # Check model compatibility model_matches = _check_server_model(self.server_port, model_name) - if not model_matches: + if model_matches: + print( + f"✅ Existing server already using correct model: {model_name}" + ) + + # Still check meta path if provided + passages_file = kwargs.get("passages_file") + if passages_file and str(passages_file).endswith( + ".meta.json" + ): + meta_matches = _check_server_meta_path( + self.server_port, str(passages_file) + ) + if not meta_matches: + print("⚠️ Updating meta path to: {passages_file}") + _update_server_meta_path( + self.server_port, str(passages_file) + ) + + return True + else: print( f"⚠️ Existing server has different model. Attempting to update to: {model_name}" ) @@ -230,11 +250,6 @@ class EmbeddingServerManager: ) return True - else: - print( - f"✅ Existing server already using correct model: {model_name}" - ) - return True else: # Server process exists but port not responding - restart print("⚠️ Server process exists but not responding. 
@@ -254,7 +269,11 @@ class EmbeddingServerManager:

         # Check model compatibility first
         model_matches = _check_server_model(port, model_name)
-        if not model_matches:
+        if model_matches:
+            print(
+                f"✅ Existing server on port {port} is using correct model: {model_name}"
+            )
+        else:
             print(
                 f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}"
             )
@@ -263,10 +282,6 @@ class EmbeddingServerManager:
                 f"❌ Failed to update server model to {model_name}. Consider using a different port."
            )
            print(f"✅ Successfully updated server model to: {model_name}")
-        else:
-            print(
-                f"✅ Existing server on port {port} is using correct model: {model_name}"
-            )

         # Check meta path compatibility if provided
         if passages_file and str(passages_file).endswith(".meta.json"):
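
Note on the wire protocol these changes rely on: the server treats a msgpack-encoded list of plain strings (with no "__" control prefix) as a direct text embedding request and replies with a list of embedding vectors. Below is a minimal client-side sketch of that round-trip; it assumes a server started by this patch is already listening on port 5557, and the example texts are placeholders.

    import msgpack
    import numpy as np
    import zmq

    texts = ["example passage one", "example passage two"]  # placeholder input

    ctx = zmq.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect("tcp://localhost:5557")

    # A plain list of strings (no "__..." prefix) is routed to the direct
    # text embedding handler added in hnsw_embedding_server.py above.
    sock.send(msgpack.packb(texts))
    embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
    print(embeddings.shape)  # (len(texts), hidden_dim)

    sock.close()
    ctx.term()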