#!/usr/bin/env python3 """ Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern """ import pickle import argparse import time import json from typing import Dict, Any, Optional, Union from transformers import AutoTokenizer, AutoModel import os from contextlib import contextmanager import zmq import numpy as np import msgpack from pathlib import Path import logging RED = "\033[91m" # Set up logging based on environment variable LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper() logging.basicConfig( level=getattr(logging, LOG_LEVEL, logging.INFO), format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) RESET = "\033[0m" # --- New Passage Loader from HNSW backend --- class SimplePassageLoader: """ Simple passage loader that replaces config.py dependencies """ def __init__(self, passages_data: Optional[Dict[str, Any]] = None): self.passages_data = passages_data or {} self._meta_path = '' def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: """Get passage by ID""" str_id = str(passage_id) if str_id in self.passages_data: return {"text": self.passages_data[str_id]} else: # Return empty text for missing passages return {"text": ""} def __len__(self) -> int: return len(self.passages_data) def keys(self): return self.passages_data.keys() def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: """ Load passages using metadata file with PassageManager for lazy loading """ # Load metadata to get passage sources with open(meta_file, 'r') as f: meta = json.load(f) # Import PassageManager dynamically to avoid circular imports import sys from pathlib import Path # Find the leann package directory relative to this file current_dir = Path(__file__).parent leann_core_path = current_dir.parent.parent / "leann-core" / "src" sys.path.insert(0, str(leann_core_path)) try: from leann.api import PassageManager passage_manager = PassageManager(meta['passage_sources']) finally: sys.path.pop(0) print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages") class LazyPassageLoader(SimplePassageLoader): def __init__(self, passage_manager): self.passage_manager = passage_manager # Initialize parent with empty data super().__init__({}) def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: """Get passage by ID with lazy loading""" try: int_id = int(passage_id) string_id = str(int_id) passage_data = self.passage_manager.get_passage(string_id) if passage_data and passage_data.get("text"): return {"text": passage_data["text"]} else: raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}") except Exception as e: raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}") def __len__(self) -> int: return len(self.passage_manager.global_offset_map) def keys(self): return self.passage_manager.global_offset_map.keys() loader = LazyPassageLoader(passage_manager) loader._meta_path = meta_file return loader def load_passages_from_file(passages_file: str) -> SimplePassageLoader: """ Load passages from a JSONL file with label map support Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line) """ if not os.path.exists(passages_file): raise FileNotFoundError(f"Passages file {passages_file} not found.") if not passages_file.endswith('.jsonl'): raise ValueError(f"Expected .jsonl file format, got: {passages_file}") # Load passages directly by their sequential IDs passages_data = {} with open(passages_file, 'r', encoding='utf-8') as f: for line in f: if line.strip(): passage = json.loads(line) passages_data[passage['id']] = passage['text'] print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}") return SimplePassageLoader(passages_data) def create_embedding_server_thread( zmq_port=5555, model_name="sentence-transformers/all-mpnet-base-v2", max_batch_size=128, passages_file: Optional[str] = None, embedding_mode: str = "sentence-transformers", enable_warmup: bool = False, ): """ Create and run embedding server in the current thread This function is designed to be called in a separate thread """ logger.info(f"Initializing embedding server thread on port {zmq_port}") try: # Check if port is already occupied import socket def check_port(port): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(('localhost', port)) == 0 if check_port(zmq_port): print(f"{RED}Port {zmq_port} is already in use{RESET}") return # Auto-detect mode based on model name if not explicitly set if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"): embedding_mode = "openai" if embedding_mode == "mlx": from leann.api import compute_embeddings_mlx import torch logger.info("Using MLX for embeddings") # Set device to CPU for compatibility with DeviceTimer class device = torch.device("cpu") cuda_available = False mps_available = False elif embedding_mode == "openai": from leann.api import compute_embeddings_openai import torch logger.info("Using OpenAI API for embeddings") # Set device to CPU for compatibility with DeviceTimer class device = torch.device("cpu") cuda_available = False mps_available = False elif embedding_mode == "sentence-transformers": # Initialize model tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) import torch # Select device mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() cuda_available = torch.cuda.is_available() if cuda_available: device = torch.device("cuda") logger.info("Using CUDA device") elif mps_available: device = torch.device("mps") logger.info("Using MPS device (Apple Silicon)") else: device = torch.device("cpu") logger.info("Using CPU device") # Load model logger.info(f"Loading model {model_name}") model = AutoModel.from_pretrained(model_name).to(device).eval() # Optimize model if cuda_available or mps_available: try: model = model.half() model = torch.compile(model) logger.info(f"Using FP16 precision with model: {model_name}") except Exception as e: print(f"WARNING: Model optimization failed: {e}") else: raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai") # Load passages from file if provided if passages_file and os.path.exists(passages_file): # Check if it's a metadata file or a single passages file if passages_file.endswith('.meta.json'): passages = load_passages_from_metadata(passages_file) else: # Try to find metadata file in same directory passages_dir = Path(passages_file).parent meta_files = list(passages_dir.glob("*.meta.json")) if meta_files: print(f"Found metadata file: {meta_files[0]}, using lazy loading") passages = load_passages_from_metadata(str(meta_files[0])) else: # Fallback to original single file loading (will cause warnings) print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)") passages = load_passages_from_file(passages_file) else: print("WARNING: No passages file provided or file not found. Using an empty passage loader.") passages = SimplePassageLoader() logger.info(f"Loaded {len(passages)} passages.") def client_warmup(zmq_port): """Perform client-side warmup for DiskANN server""" time.sleep(2) print(f"Performing client-side warmup with model {model_name}...") # Get actual passage IDs from the loaded passages sample_ids = [] if hasattr(passages, 'keys') and len(passages) > 0: available_ids = list(passages.keys()) # Take up to 5 actual IDs, but at least 1 sample_ids = available_ids[:min(5, len(available_ids))] print(f"Using actual passage IDs for warmup: {sample_ids}") else: print("No passages available for warmup, skipping warmup...") return try: context = zmq.Context() socket = context.socket(zmq.REQ) socket.connect(f"tcp://localhost:{zmq_port}") socket.setsockopt(zmq.RCVTIMEO, 30000) socket.setsockopt(zmq.SNDTIMEO, 30000) try: ids_to_send = [int(x) for x in sample_ids] except ValueError: print("Warning: Could not convert sample IDs to integers, skipping warmup") return if not ids_to_send: print("Skipping warmup send.") return # Use protobuf format for warmup from . import embedding_pb2 req_proto = embedding_pb2.NodeEmbeddingRequest() req_proto.node_ids.extend(ids_to_send) request_bytes = req_proto.SerializeToString() for i in range(3): print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...") socket.send(request_bytes) response_bytes = socket.recv() resp_proto = embedding_pb2.NodeEmbeddingResponse() resp_proto.ParseFromString(response_bytes) embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0 print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings") time.sleep(0.1) print("Client-side Protobuf ZMQ warmup complete") socket.close() context.term() except Exception as e: print(f"Error during Protobuf ZMQ warmup: {e}") class DeviceTimer: """Device timer""" def __init__(self, name="", device=device): self.name = name self.device = device self.start_time = 0 self.end_time = 0 if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): self.start_event = torch.cuda.Event(enable_timing=True) self.end_event = torch.cuda.Event(enable_timing=True) else: self.start_event = None self.end_event = None @contextmanager def timing(self): self.start() yield self.end() def start(self): if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): torch.cuda.synchronize() self.start_event.record() else: if embedding_mode == "sentence-transformers" and self.device.type == "mps": torch.mps.synchronize() self.start_time = time.time() def end(self): if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): self.end_event.record() torch.cuda.synchronize() else: if embedding_mode == "sentence-transformers" and self.device.type == "mps": torch.mps.synchronize() self.end_time = time.time() def elapsed_time(self): if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): return self.start_event.elapsed_time(self.end_event) / 1000.0 else: return self.end_time - self.start_time def print_elapsed(self): elapsed = self.elapsed_time() print(f"[{self.name}] Elapsed time: {elapsed:.3f}s") def process_batch_pytorch(texts_batch, ids_batch, missing_ids): """Process text batch""" if not texts_batch: return np.array([]) # Filter out empty texts and their corresponding IDs valid_texts = [] valid_ids = [] for i, text in enumerate(texts_batch): if text.strip(): # Only include non-empty texts valid_texts.append(text) valid_ids.append(ids_batch[i]) if not valid_texts: print("WARNING: No valid texts in batch") return np.array([]) # Tokenize token_timer = DeviceTimer("tokenization") with token_timer.timing(): inputs = tokenizer( valid_texts, padding=True, truncation=True, max_length=512, return_tensors="pt" ).to(device) # Compute embeddings embed_timer = DeviceTimer("embedding computation") with embed_timer.timing(): with torch.no_grad(): outputs = model(**inputs) hidden_states = outputs.last_hidden_state # Mean pooling attention_mask = inputs['attention_mask'] mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float() sum_embeddings = torch.sum(hidden_states * mask_expanded, 1) sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) batch_embeddings = sum_embeddings / sum_mask embed_timer.print_elapsed() return batch_embeddings.cpu().numpy() # ZMQ server main loop - modified to use REP socket context = zmq.Context() socket = context.socket(zmq.ROUTER) # Changed to REP socket socket.bind(f"tcp://127.0.0.1:{zmq_port}") print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}") # Set timeouts socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout from . import embedding_pb2 print(f"INFO: Embedding server ready to serve requests") # Start warmup thread if enabled if enable_warmup and len(passages) > 0: import threading print(f"Warmup enabled: starting warmup thread") warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,)) warmup_thread.daemon = True warmup_thread.start() else: print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})") while True: try: parts = socket.recv_multipart() # --- Restore robust message format detection --- # Must check parts length to avoid IndexError if len(parts) >= 3: identity = parts[0] # empty = parts[1] # We usually don't care about the middle empty frame message = parts[2] elif len(parts) == 2: # Can also handle cases without empty frame identity = parts[0] message = parts[1] else: # If received message format is wrong, print warning and ignore it instead of crashing print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.") continue print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes") # Handle control messages (MessagePack format) try: request_payload = msgpack.unpackb(message) if isinstance(request_payload, list) and len(request_payload) >= 1: if request_payload[0] == "__QUERY_META_PATH__": # Return the current meta path being used by the server current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else '' response = [current_meta_path] socket.send_multipart([identity, b'', msgpack.packb(response)]) continue elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2: # Update the server's meta path and reload passages new_meta_path = request_payload[1] try: print(f"INFO: Updating server meta path to: {new_meta_path}") # Reload passages from the new meta file passages = load_passages_from_metadata(new_meta_path) # Store the meta path for future queries passages._meta_path = new_meta_path response = ["SUCCESS"] print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages") except Exception as e: print(f"ERROR: Failed to update meta path: {e}") response = ["FAILED", str(e)] socket.send_multipart([identity, b'', msgpack.packb(response)]) continue elif request_payload[0] == "__QUERY_MODEL__": # Return the current model being used by the server response = [model_name] socket.send_multipart([identity, b'', msgpack.packb(response)]) continue elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2: # Update the server's embedding model new_model_name = request_payload[1] try: print(f"INFO: Updating server model from {model_name} to: {new_model_name}") # Clean up old model to free memory if not use_mlx: print("INFO: Releasing old model from memory...") old_model = model old_tokenizer = tokenizer # Load new tokenizer first print(f"Loading new tokenizer for {new_model_name}...") tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True) # Load new model print(f"Loading new model {new_model_name}...") model = AutoModel.from_pretrained(new_model_name).to(device).eval() # Optimize new model if cuda_available or mps_available: try: model = model.half() model = torch.compile(model) print(f"INFO: Using FP16 precision with model: {new_model_name}") except Exception as e: print(f"WARNING: Model optimization failed: {e}") # Now safely delete old model after new one is loaded del old_model del old_tokenizer # Clear GPU cache if available if device.type == "cuda": torch.cuda.empty_cache() print("INFO: Cleared CUDA cache") elif device.type == "mps": torch.mps.empty_cache() print("INFO: Cleared MPS cache") # Force garbage collection import gc gc.collect() print("INFO: Memory cleanup completed") # Update model name model_name = new_model_name response = ["SUCCESS"] print(f"INFO: Successfully updated model to: {new_model_name}") except Exception as e: print(f"ERROR: Failed to update model: {e}") response = ["FAILED", str(e)] socket.send_multipart([identity, b'', msgpack.packb(response)]) continue except: # Not a control message, continue with normal protobuf processing pass e2e_start = time.time() lookup_timer = DeviceTimer("text lookup") # Parse request req_proto = embedding_pb2.NodeEmbeddingRequest() req_proto.ParseFromString(message) node_ids = req_proto.node_ids print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}") # Add debug information if len(node_ids) > 0: print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}") # Look up texts texts = [] missing_ids = [] with lookup_timer.timing(): for nid in node_ids: txtinfo = passages[nid] txt = txtinfo["text"] if txt: texts.append(txt) else: # If text is empty, we still need a placeholder for batch processing, # but record its ID as missing texts.append("") missing_ids.append(nid) lookup_timer.print_elapsed() if missing_ids: print(f"WARNING: Missing passages for IDs: {missing_ids}") # Process batch total_size = len(texts) print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}") all_embeddings = [] if total_size > max_batch_size: print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}") for i in range(0, total_size, max_batch_size): end_idx = min(i + max_batch_size, total_size) print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}") chunk_texts = texts[i:end_idx] chunk_ids = node_ids[i:end_idx] if embedding_mode == "mlx": embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16) elif embedding_mode == "openai": embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name) else: # sentence-transformers embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids) all_embeddings.append(embeddings_chunk) if embedding_mode == "sentence-transformers": if cuda_available: torch.cuda.empty_cache() elif device.type == "mps": torch.mps.empty_cache() hidden = np.vstack(all_embeddings) print(f"INFO: Combined embeddings shape: {hidden.shape}") else: if embedding_mode == "mlx": hidden = compute_embeddings_mlx(texts, model_name, batch_size=16) elif embedding_mode == "openai": hidden = compute_embeddings_openai(texts, model_name) else: # sentence-transformers hidden = process_batch_pytorch(texts, node_ids, missing_ids) # Serialize response ser_start = time.time() resp_proto = embedding_pb2.NodeEmbeddingResponse() hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32) resp_proto.embeddings_data = hidden_contiguous.tobytes() resp_proto.dimensions.append(hidden_contiguous.shape[0]) resp_proto.dimensions.append(hidden_contiguous.shape[1]) resp_proto.missing_ids.extend(missing_ids) response_data = resp_proto.SerializeToString() # REP socket sends a single response socket.send_multipart([identity, b'', response_data]) ser_end = time.time() print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds") if embedding_mode == "sentence-transformers": if device.type == "cuda": torch.cuda.synchronize() elif device.type == "mps": torch.mps.synchronize() e2e_end = time.time() print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds") except zmq.Again: print("INFO: ZMQ socket timeout, continuing to listen") continue except Exception as e: print(f"ERROR: Error in ZMQ server: {e}") try: # Send empty response to maintain REQ-REP state empty_resp = embedding_pb2.NodeEmbeddingResponse() socket.send(empty_resp.SerializeToString()) except: # If sending fails, recreate socket socket.close() socket = context.socket(zmq.REP) socket.bind(f"tcp://127.0.0.1:{zmq_port}") socket.setsockopt(zmq.RCVTIMEO, 5000) socket.setsockopt(zmq.SNDTIMEO, 300000) print("INFO: ZMQ socket recreated after error") except Exception as e: print(f"ERROR: Failed to start embedding server: {e}") raise def create_embedding_server( domain="demo", load_passages=True, load_embeddings=False, use_fp16=True, use_int8=False, use_cuda_graphs=False, zmq_port=5555, max_batch_size=128, lazy_load_passages=False, model_name="sentence-transformers/all-mpnet-base-v2", passages_file: Optional[str] = None, embedding_mode: str = "sentence-transformers", enable_warmup: bool = False, ): """ 原有的 create_embedding_server 函数保持不变 这个是阻塞版本,用于直接运行 """ create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Embedding service") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--domain", type=str, default="demo", help="Domain name") parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping") parser.add_argument("--load-passages", action="store_true", default=True) parser.add_argument("--load-embeddings", action="store_true", default=False) parser.add_argument("--use-fp16", action="store_true", default=False) parser.add_argument("--use-int8", action="store_true", default=False) parser.add_argument("--use-cuda-graphs", action="store_true", default=False) parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting") parser.add_argument("--lazy-load-passages", action="store_true", default=True) parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", help="Embedding model name") parser.add_argument("--embedding-mode", type=str, default="sentence-transformers", choices=["sentence-transformers", "mlx", "openai"], help="Embedding backend mode") parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)") parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start") args = parser.parse_args() # Handle backward compatibility with use_mlx embedding_mode = args.embedding_mode if args.use_mlx: embedding_mode = "mlx" create_embedding_server( domain=args.domain, load_passages=args.load_passages, load_embeddings=args.load_embeddings, use_fp16=args.use_fp16, use_int8=args.use_int8, use_cuda_graphs=args.use_cuda_graphs, zmq_port=args.zmq_port, max_batch_size=args.max_batch_size, lazy_load_passages=args.lazy_load_passages, model_name=args.model_name, passages_file=args.passages_file, embedding_mode=embedding_mode, enable_warmup=not args.disable_warmup, )