diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py index a28a744..b4df2f7 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py @@ -1,10 +1,13 @@ import numpy as np import os import struct +import sys from pathlib import Path -from typing import Dict, Any, List, Literal +from typing import Dict, Any, List, Literal, Optional import contextlib +import logging + from leann.searcher_base import BaseSearcher from leann.registry import register_backend from leann.interface import ( @@ -13,6 +16,46 @@ from leann.interface import ( LeannBackendSearcherInterface, ) +logger = logging.getLogger(__name__) + + +@contextlib.contextmanager +def suppress_cpp_output_if_needed(): + """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL""" + log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() + + # Only suppress if log level is WARNING or higher (ERROR, CRITICAL) + should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"] + + if not should_suppress: + # Don't suppress, just yield + yield + return + + # Save original file descriptors + stdout_fd = sys.stdout.fileno() + stderr_fd = sys.stderr.fileno() + + # Save original stdout/stderr + stdout_dup = os.dup(stdout_fd) + stderr_dup = os.dup(stderr_fd) + + try: + # Redirect to /dev/null + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, stdout_fd) + os.dup2(devnull, stderr_fd) + os.close(devnull) + + yield + + finally: + # Restore original file descriptors + os.dup2(stdout_dup, stdout_fd) + os.dup2(stderr_dup, stderr_fd) + os.close(stdout_dup) + os.close(stderr_dup) + def _get_diskann_metrics(): from . import _diskannpy as diskannpy # type: ignore @@ -64,6 +107,7 @@ class DiskannBuilder(LeannBackendBuilderInterface): index_dir.mkdir(parents=True, exist_ok=True) if data.dtype != np.float32: + logger.warning(f"Converting data to float32, shape: {data.shape}") data = data.astype(np.float32) data_filename = f"{index_prefix}_data.bin" @@ -74,7 +118,9 @@ class DiskannBuilder(LeannBackendBuilderInterface): build_kwargs.get("distance_metric", "mips").lower() ) if metric_enum is None: - raise ValueError("Unsupported distance_metric.") + raise ValueError( + f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'." + ) try: from . import _diskannpy as diskannpy # type: ignore @@ -96,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface): temp_data_file = index_dir / data_filename if temp_data_file.exists(): os.remove(temp_data_file) + logger.debug(f"Cleaned up temporary data file: {temp_data_file}") class DiskannSearcher(BaseSearcher): def __init__(self, index_path: str, **kwargs): super().__init__( index_path, - backend_module_name="leann_backend_diskann.embedding_server", + backend_module_name="leann_backend_diskann.diskann_embedding_server", **kwargs, ) - from . import _diskannpy as diskannpy # type: ignore - distance_metric = kwargs.get("distance_metric", "mips").lower() - metric_enum = _get_diskann_metrics().get(distance_metric) - if metric_enum is None: - raise ValueError(f"Unsupported distance_metric '{distance_metric}'.") + # Initialize DiskANN index with suppressed C++ output based on log level + with suppress_cpp_output_if_needed(): + from . 
import _diskannpy as diskannpy # type: ignore - self.num_threads = kwargs.get("num_threads", 8) - self.zmq_port = kwargs.get("zmq_port", 6666) + distance_metric = kwargs.get("distance_metric", "mips").lower() + metric_enum = _get_diskann_metrics().get(distance_metric) + if metric_enum is None: + raise ValueError(f"Unsupported distance_metric '{distance_metric}'.") - full_index_prefix = str(self.index_dir / self.index_path.stem) - self._index = diskannpy.StaticDiskFloatIndex( - metric_enum, - full_index_prefix, - self.num_threads, - kwargs.get("num_nodes_to_cache", 0), - 1, - self.zmq_port, - "", - "", - ) + self.num_threads = kwargs.get("num_threads", 8) + + fake_zmq_port = 6666 + full_index_prefix = str(self.index_dir / self.index_path.stem) + self._index = diskannpy.StaticDiskFloatIndex( + metric_enum, + full_index_prefix, + self.num_threads, + kwargs.get("num_nodes_to_cache", 0), + 1, + fake_zmq_port, # Initial port, can be updated at runtime + "", + "", + ) def search( self, @@ -136,7 +186,7 @@ class DiskannSearcher(BaseSearcher): prune_ratio: float = 0.0, recompute_embeddings: bool = False, pruning_strategy: Literal["global", "local", "proportional"] = "global", - zmq_port: int = 5557, + zmq_port: Optional[int] = None, batch_recompute: bool = False, dedup_node_dis: bool = False, **kwargs, @@ -155,7 +205,7 @@ class DiskannSearcher(BaseSearcher): - "global": Use global pruning strategy (default) - "local": Use local pruning strategy - "proportional": Not supported in DiskANN, falls back to global - zmq_port: ZMQ port for embedding server + zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True. batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific) dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific) **kwargs: Additional DiskANN-specific parameters (for legacy compatibility) @@ -163,22 +213,25 @@ class DiskannSearcher(BaseSearcher): Returns: Dict with 'labels' (list of lists) and 'distances' (ndarray) """ + # Handle zmq_port compatibility: DiskANN can now update port at runtime + if recompute_embeddings: + if zmq_port is None: + raise ValueError( + "zmq_port must be provided if recompute_embeddings is True" + ) + current_port = self._index.get_zmq_port() + if zmq_port != current_port: + logger.debug( + f"Updating DiskANN zmq_port from {current_port} to {zmq_port}" + ) + self._index.set_zmq_port(zmq_port) + # DiskANN doesn't support "proportional" strategy if pruning_strategy == "proportional": raise NotImplementedError( "DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead." 
             )
 
-        # Use recompute_embeddings parameter
-        use_recompute = recompute_embeddings
-        if use_recompute:
-            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
-            if not meta_file_path.exists():
-                raise RuntimeError(
-                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
-                )
-            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
-
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -188,21 +241,23 @@ class DiskannSearcher(BaseSearcher):
         else:  # "global"
             use_global_pruning = True
 
-        labels, distances = self._index.batch_search(
-            query,
-            query.shape[0],
-            top_k,
-            complexity,
-            beam_width,
-            self.num_threads,
-            kwargs.get("USE_DEFERRED_FETCH", False),
-            kwargs.get("skip_search_reorder", False),
-            use_recompute,
-            dedup_node_dis,
-            prune_ratio,
-            batch_recompute,
-            use_global_pruning,
-        )
+        # Perform search with suppressed C++ output based on log level
+        with suppress_cpp_output_if_needed():
+            labels, distances = self._index.batch_search(
+                query,
+                query.shape[0],
+                top_k,
+                complexity,
+                beam_width,
+                self.num_threads,
+                kwargs.get("USE_DEFERRED_FETCH", False),
+                kwargs.get("skip_search_reorder", False),
+                recompute_embeddings,
+                dedup_node_dis,
+                prune_ratio,
+                batch_recompute,
+                use_global_pruning,
+            )
 
         string_labels = [
             [str(int_label) for int_label in batch_labels] for batch_labels in labels
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
new file mode 100644
index 0000000..18bcd09
--- /dev/null
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
@@ -0,0 +1,269 @@
+"""
+DiskANN-specific embedding server
+"""
+
+import argparse
+import threading
+import time
+import os
+import zmq
+import numpy as np
+import json
+from pathlib import Path
+from typing import Optional
+import sys
+import logging
+
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+logger = logging.getLogger(__name__)
+
+# Force set logger level (don't rely on basicConfig in subprocess)
+log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
+logger.setLevel(log_level)
+
+# Ensure we have a handler if none exists
+if not logger.handlers:
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.propagate = False
+
+
+def create_diskann_embedding_server(
+    passages_file: Optional[str] = None,
+    zmq_port: int = 5555,
+    model_name: str = "sentence-transformers/all-mpnet-base-v2",
+    embedding_mode: str = "sentence-transformers",
+):
+    """
+    Create and start a ZMQ-based embedding server for DiskANN backend.
+    Uses a REP socket and handles both protobuf node-ID requests (DiskANN C++ client) and msgpack text requests (BaseSearcher).
+ """ + logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}") + logger.info(f"Using embedding mode: {embedding_mode}") + + # Add leann-core to path for unified embedding computation + current_dir = Path(__file__).parent + leann_core_path = current_dir.parent.parent / "leann-core" / "src" + sys.path.insert(0, str(leann_core_path)) + + try: + from leann.embedding_compute import compute_embeddings + from leann.api import PassageManager + + logger.info("Successfully imported unified embedding computation module") + except ImportError as e: + logger.error(f"Failed to import embedding computation module: {e}") + return + finally: + sys.path.pop(0) + + # Check port availability + import socket + + def check_port(port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + if check_port(zmq_port): + logger.error(f"Port {zmq_port} is already in use") + return + + # Only support metadata file, fail fast for everything else + if not passages_file or not passages_file.endswith(".meta.json"): + raise ValueError("Only metadata files (.meta.json) are supported") + + # Load metadata to get passage sources + with open(passages_file, "r") as f: + meta = json.load(f) + + passages = PassageManager(meta["passage_sources"]) + logger.info( + f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" + ) + + # Import protobuf after ensuring the path is correct + try: + from . import embedding_pb2 + except ImportError as e: + logger.error(f"Failed to import protobuf module: {e}") + return + + def zmq_server_thread(): + """ZMQ server thread using REP socket for universal compatibility""" + context = zmq.Context() + socket = context.socket(zmq.REP) # REP socket for both BaseSearcher and DiskANN C++ REQ clients + socket.bind(f"tcp://*:{zmq_port}") + logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}") + + socket.setsockopt(zmq.RCVTIMEO, 300000) + socket.setsockopt(zmq.SNDTIMEO, 300000) + + while True: + try: + # REP socket receives single-part messages + message = socket.recv() + + # Check for empty messages - REP socket requires response to every request + if len(message) == 0: + logger.debug("Received empty message, sending empty response") + socket.send(b"") # REP socket must respond to every request + continue + + logger.debug(f"Received ZMQ request of size {len(message)} bytes") + logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes + + e2e_start = time.time() + + # Try protobuf first (for DiskANN C++ node_ids requests - primary use case) + texts = [] + node_ids = [] + is_text_request = False + + try: + req_proto = embedding_pb2.NodeEmbeddingRequest() + req_proto.ParseFromString(message) + node_ids = list(req_proto.node_ids) + + if not node_ids: + raise RuntimeError(f"PROTOBUF: Received empty node_ids! 
Message size: {len(message)}") + + logger.info( + f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}" + ) + except Exception as protobuf_error: + logger.debug(f"Protobuf parsing failed: {protobuf_error}") + # Fallback to msgpack (for BaseSearcher direct text requests) + try: + import msgpack + + request = msgpack.unpackb(message) + # For BaseSearcher compatibility, request is a list of texts directly + if isinstance(request, list) and all( + isinstance(item, str) for item in request + ): + texts = request + is_text_request = True + logger.info( + f"✅ MSGPACK: Direct text request for {len(texts)} texts" + ) + else: + raise ValueError("Not a valid msgpack text request") + except Exception as msgpack_error: + raise RuntimeError( + f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}" + ) + + # Look up texts by node IDs (only if not direct text request) + if not is_text_request: + for nid in node_ids: + try: + passage_data = passages.get_passage(str(nid)) + txt = passage_data["text"] + if not txt: + raise RuntimeError( + f"FATAL: Empty text for passage ID {nid}" + ) + texts.append(txt) + except KeyError as e: + logger.error(f"Passage ID {nid} not found: {e}") + raise e + except Exception as e: + logger.error(f"Exception looking up passage ID {nid}: {e}") + raise + + # Debug logging + logger.debug( + f"Processing {len(texts)} texts" + ) + logger.debug( + f"Text lengths: {[len(t) for t in texts[:5]]}" + ) # Show first 5 + + # Process embeddings using unified computation + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + + # Prepare response based on request type + if is_text_request: + # For BaseSearcher compatibility: return msgpack format + import msgpack + + response_data = msgpack.packb(embeddings.tolist()) + else: + # For DiskANN C++ compatibility: return protobuf format + resp_proto = embedding_pb2.NodeEmbeddingResponse() + hidden_contiguous = np.ascontiguousarray( + embeddings, dtype=np.float32 + ) + + # Serialize embeddings data + resp_proto.embeddings_data = hidden_contiguous.tobytes() + resp_proto.dimensions.append(hidden_contiguous.shape[0]) + resp_proto.dimensions.append(hidden_contiguous.shape[1]) + + response_data = resp_proto.SerializeToString() + + # Send response back to the client + socket.send(response_data) + + e2e_end = time.time() + logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") + + except zmq.Again: + logger.debug("ZMQ socket timeout, continuing to listen") + continue + except Exception as e: + logger.error(f"Error in ZMQ server loop: {e}") + import traceback + traceback.print_exc() + raise + + zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) + zmq_thread.start() + logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}") + + # Keep the main thread alive + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("DiskANN Server shutting down...") + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="DiskANN Embedding service") + parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") + parser.add_argument( + "--passages-file", + type=str, + help="Metadata JSON file containing passage sources", + ) + parser.add_argument( + "--model-name", + type=str, + default="sentence-transformers/all-mpnet-base-v2", + help="Embedding model name", + ) + 
parser.add_argument( + "--embedding-mode", + type=str, + default="sentence-transformers", + choices=["sentence-transformers", "openai", "mlx"], + help="Embedding backend mode", + ) + + args = parser.parse_args() + + # Create and start the DiskANN embedding server + create_diskann_embedding_server( + passages_file=args.passages_file, + zmq_port=args.zmq_port, + model_name=args.model_name, + embedding_mode=args.embedding_mode, + ) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py deleted file mode 100644 index 04f7f56..0000000 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ /dev/null @@ -1,705 +0,0 @@ -#!/usr/bin/env python3 -""" -Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern -""" - -import pickle -import argparse -import time -import json -from typing import Dict, Any, Optional, Union - -from transformers import AutoTokenizer, AutoModel -import os -from contextlib import contextmanager -import zmq -import numpy as np -import msgpack -from pathlib import Path -import logging - -RED = "\033[91m" - -# Set up logging based on environment variable -LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper() -logging.basicConfig( - level=getattr(logging, LOG_LEVEL, logging.INFO), - format='%(asctime)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) -RESET = "\033[0m" - -# --- New Passage Loader from HNSW backend --- -class SimplePassageLoader: - """ - Simple passage loader that replaces config.py dependencies - """ - def __init__(self, passages_data: Optional[Dict[str, Any]] = None): - self.passages_data = passages_data or {} - self._meta_path = '' - - def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: - """Get passage by ID""" - str_id = str(passage_id) - if str_id in self.passages_data: - return {"text": self.passages_data[str_id]} - else: - # Return empty text for missing passages - return {"text": ""} - - def __len__(self) -> int: - return len(self.passages_data) - - def keys(self): - return self.passages_data.keys() - -def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: - """ - Load passages using metadata file with PassageManager for lazy loading - """ - # Load metadata to get passage sources - with open(meta_file, 'r') as f: - meta = json.load(f) - - # Import PassageManager dynamically to avoid circular imports - import sys - from pathlib import Path - - # Find the leann package directory relative to this file - current_dir = Path(__file__).parent - leann_core_path = current_dir.parent.parent / "leann-core" / "src" - sys.path.insert(0, str(leann_core_path)) - - try: - from leann.api import PassageManager - passage_manager = PassageManager(meta['passage_sources']) - finally: - sys.path.pop(0) - - print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages") - - class LazyPassageLoader(SimplePassageLoader): - def __init__(self, passage_manager): - self.passage_manager = passage_manager - # Initialize parent with empty data - super().__init__({}) - - def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: - """Get passage by ID with lazy loading""" - try: - int_id = int(passage_id) - string_id = str(int_id) - passage_data = self.passage_manager.get_passage(string_id) - if passage_data and passage_data.get("text"): - return {"text": passage_data["text"]} - else: - raise RuntimeError(f"FATAL: Empty text for ID {int_id} 
-> {string_id}") - except Exception as e: - raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}") - - def __len__(self) -> int: - return len(self.passage_manager.global_offset_map) - - def keys(self): - return self.passage_manager.global_offset_map.keys() - - loader = LazyPassageLoader(passage_manager) - loader._meta_path = meta_file - return loader - -def load_passages_from_file(passages_file: str) -> SimplePassageLoader: - """ - Load passages from a JSONL file with label map support - Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line) - """ - - if not os.path.exists(passages_file): - raise FileNotFoundError(f"Passages file {passages_file} not found.") - - if not passages_file.endswith('.jsonl'): - raise ValueError(f"Expected .jsonl file format, got: {passages_file}") - - # Load passages directly by their sequential IDs - passages_data = {} - with open(passages_file, 'r', encoding='utf-8') as f: - for line in f: - if line.strip(): - passage = json.loads(line) - passages_data[passage['id']] = passage['text'] - - print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}") - return SimplePassageLoader(passages_data) - -def create_embedding_server_thread( - zmq_port=5555, - model_name="sentence-transformers/all-mpnet-base-v2", - max_batch_size=128, - passages_file: Optional[str] = None, - embedding_mode: str = "sentence-transformers", - enable_warmup: bool = False, -): - """ - Create and run embedding server in the current thread - This function is designed to be called in a separate thread - """ - logger.info(f"Initializing embedding server thread on port {zmq_port}") - - try: - # Check if port is already occupied - import socket - def check_port(port): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 - - if check_port(zmq_port): - print(f"{RED}Port {zmq_port} is already in use{RESET}") - return - - # Auto-detect mode based on model name if not explicitly set - if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"): - embedding_mode = "openai" - - if embedding_mode == "mlx": - from leann.api import compute_embeddings_mlx - import torch - logger.info("Using MLX for embeddings") - # Set device to CPU for compatibility with DeviceTimer class - device = torch.device("cpu") - cuda_available = False - mps_available = False - elif embedding_mode == "openai": - from leann.api import compute_embeddings_openai - import torch - logger.info("Using OpenAI API for embeddings") - # Set device to CPU for compatibility with DeviceTimer class - device = torch.device("cpu") - cuda_available = False - mps_available = False - elif embedding_mode == "sentence-transformers": - # Initialize model - tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) - import torch - - # Select device - mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available() - cuda_available = torch.cuda.is_available() - - if cuda_available: - device = torch.device("cuda") - logger.info("Using CUDA device") - elif mps_available: - device = torch.device("mps") - logger.info("Using MPS device (Apple Silicon)") - else: - device = torch.device("cpu") - logger.info("Using CPU device") - - # Load model - logger.info(f"Loading model {model_name}") - model = AutoModel.from_pretrained(model_name).to(device).eval() - - # Optimize model - if cuda_available or mps_available: - try: - model = model.half() - model = torch.compile(model) - 
logger.info(f"Using FP16 precision with model: {model_name}") - except Exception as e: - print(f"WARNING: Model optimization failed: {e}") - else: - raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai") - - # Load passages from file if provided - if passages_file and os.path.exists(passages_file): - # Check if it's a metadata file or a single passages file - if passages_file.endswith('.meta.json'): - passages = load_passages_from_metadata(passages_file) - else: - # Try to find metadata file in same directory - passages_dir = Path(passages_file).parent - meta_files = list(passages_dir.glob("*.meta.json")) - if meta_files: - print(f"Found metadata file: {meta_files[0]}, using lazy loading") - passages = load_passages_from_metadata(str(meta_files[0])) - else: - # Fallback to original single file loading (will cause warnings) - print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)") - passages = load_passages_from_file(passages_file) - else: - print("WARNING: No passages file provided or file not found. Using an empty passage loader.") - passages = SimplePassageLoader() - - logger.info(f"Loaded {len(passages)} passages.") - - def client_warmup(zmq_port): - """Perform client-side warmup for DiskANN server""" - time.sleep(2) - print(f"Performing client-side warmup with model {model_name}...") - - # Get actual passage IDs from the loaded passages - sample_ids = [] - if hasattr(passages, 'keys') and len(passages) > 0: - available_ids = list(passages.keys()) - # Take up to 5 actual IDs, but at least 1 - sample_ids = available_ids[:min(5, len(available_ids))] - print(f"Using actual passage IDs for warmup: {sample_ids}") - else: - print("No passages available for warmup, skipping warmup...") - return - - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect(f"tcp://localhost:{zmq_port}") - socket.setsockopt(zmq.RCVTIMEO, 30000) - socket.setsockopt(zmq.SNDTIMEO, 30000) - - try: - ids_to_send = [int(x) for x in sample_ids] - except ValueError: - print("Warning: Could not convert sample IDs to integers, skipping warmup") - return - - if not ids_to_send: - print("Skipping warmup send.") - return - - # Use protobuf format for warmup - from . 
import embedding_pb2 - req_proto = embedding_pb2.NodeEmbeddingRequest() - req_proto.node_ids.extend(ids_to_send) - request_bytes = req_proto.SerializeToString() - - for i in range(3): - print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...") - socket.send(request_bytes) - response_bytes = socket.recv() - - resp_proto = embedding_pb2.NodeEmbeddingResponse() - resp_proto.ParseFromString(response_bytes) - embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0 - print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings") - time.sleep(0.1) - - print("Client-side Protobuf ZMQ warmup complete") - socket.close() - context.term() - except Exception as e: - print(f"Error during Protobuf ZMQ warmup: {e}") - - class DeviceTimer: - """Device timer""" - def __init__(self, name="", device=device): - self.name = name - self.device = device - self.start_time = 0 - self.end_time = 0 - - if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - else: - self.start_event = None - self.end_event = None - - @contextmanager - def timing(self): - self.start() - yield - self.end() - - def start(self): - if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): - torch.cuda.synchronize() - self.start_event.record() - else: - if embedding_mode == "sentence-transformers" and self.device.type == "mps": - torch.mps.synchronize() - self.start_time = time.time() - - def end(self): - if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): - self.end_event.record() - torch.cuda.synchronize() - else: - if embedding_mode == "sentence-transformers" and self.device.type == "mps": - torch.mps.synchronize() - self.end_time = time.time() - - def elapsed_time(self): - if embedding_mode == "sentence-transformers" and torch.cuda.is_available(): - return self.start_event.elapsed_time(self.end_event) / 1000.0 - else: - return self.end_time - self.start_time - - def print_elapsed(self): - elapsed = self.elapsed_time() - print(f"[{self.name}] Elapsed time: {elapsed:.3f}s") - - def process_batch_pytorch(texts_batch, ids_batch, missing_ids): - """Process text batch""" - if not texts_batch: - return np.array([]) - - # Filter out empty texts and their corresponding IDs - valid_texts = [] - valid_ids = [] - for i, text in enumerate(texts_batch): - if text.strip(): # Only include non-empty texts - valid_texts.append(text) - valid_ids.append(ids_batch[i]) - - if not valid_texts: - print("WARNING: No valid texts in batch") - return np.array([]) - - # Tokenize - token_timer = DeviceTimer("tokenization") - with token_timer.timing(): - inputs = tokenizer( - valid_texts, - padding=True, - truncation=True, - max_length=512, - return_tensors="pt" - ).to(device) - - # Compute embeddings - embed_timer = DeviceTimer("embedding computation") - with embed_timer.timing(): - with torch.no_grad(): - outputs = model(**inputs) - hidden_states = outputs.last_hidden_state - - # Mean pooling - attention_mask = inputs['attention_mask'] - mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float() - sum_embeddings = torch.sum(hidden_states * mask_expanded, 1) - sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) - batch_embeddings = sum_embeddings / sum_mask - embed_timer.print_elapsed() - - return batch_embeddings.cpu().numpy() - - # ZMQ server main loop - modified to use REP socket - context = zmq.Context() - socket 
= context.socket(zmq.ROUTER) # Changed to REP socket - socket.bind(f"tcp://127.0.0.1:{zmq_port}") - print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}") - - # Set timeouts - socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout - socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout - - from . import embedding_pb2 - - print(f"INFO: Embedding server ready to serve requests") - - # Start warmup thread if enabled - if enable_warmup and len(passages) > 0: - import threading - print(f"Warmup enabled: starting warmup thread") - warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,)) - warmup_thread.daemon = True - warmup_thread.start() - else: - print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})") - - while True: - try: - parts = socket.recv_multipart() - - # --- Restore robust message format detection --- - # Must check parts length to avoid IndexError - if len(parts) >= 3: - identity = parts[0] - # empty = parts[1] # We usually don't care about the middle empty frame - message = parts[2] - elif len(parts) == 2: - # Can also handle cases without empty frame - identity = parts[0] - message = parts[1] - else: - # If received message format is wrong, print warning and ignore it instead of crashing - print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.") - continue - print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes") - - # Handle control messages (MessagePack format) - try: - request_payload = msgpack.unpackb(message) - if isinstance(request_payload, list) and len(request_payload) >= 1: - if request_payload[0] == "__QUERY_META_PATH__": - # Return the current meta path being used by the server - current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else '' - response = [current_meta_path] - socket.send_multipart([identity, b'', msgpack.packb(response)]) - continue - - elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2: - # Update the server's meta path and reload passages - new_meta_path = request_payload[1] - try: - print(f"INFO: Updating server meta path to: {new_meta_path}") - # Reload passages from the new meta file - passages = load_passages_from_metadata(new_meta_path) - # Store the meta path for future queries - passages._meta_path = new_meta_path - response = ["SUCCESS"] - print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages") - except Exception as e: - print(f"ERROR: Failed to update meta path: {e}") - response = ["FAILED", str(e)] - socket.send_multipart([identity, b'', msgpack.packb(response)]) - continue - - elif request_payload[0] == "__QUERY_MODEL__": - # Return the current model being used by the server - response = [model_name] - socket.send_multipart([identity, b'', msgpack.packb(response)]) - continue - - elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2: - # Update the server's embedding model - new_model_name = request_payload[1] - try: - print(f"INFO: Updating server model from {model_name} to: {new_model_name}") - - # Clean up old model to free memory - if not use_mlx: - print("INFO: Releasing old model from memory...") - old_model = model - old_tokenizer = tokenizer - - # Load new tokenizer first - print(f"Loading new tokenizer for {new_model_name}...") - tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True) - - # Load new model - print(f"Loading new model 
{new_model_name}...") - model = AutoModel.from_pretrained(new_model_name).to(device).eval() - - # Optimize new model - if cuda_available or mps_available: - try: - model = model.half() - model = torch.compile(model) - print(f"INFO: Using FP16 precision with model: {new_model_name}") - except Exception as e: - print(f"WARNING: Model optimization failed: {e}") - - # Now safely delete old model after new one is loaded - del old_model - del old_tokenizer - - # Clear GPU cache if available - if device.type == "cuda": - torch.cuda.empty_cache() - print("INFO: Cleared CUDA cache") - elif device.type == "mps": - torch.mps.empty_cache() - print("INFO: Cleared MPS cache") - - # Force garbage collection - import gc - gc.collect() - print("INFO: Memory cleanup completed") - - # Update model name - model_name = new_model_name - - response = ["SUCCESS"] - print(f"INFO: Successfully updated model to: {new_model_name}") - except Exception as e: - print(f"ERROR: Failed to update model: {e}") - response = ["FAILED", str(e)] - socket.send_multipart([identity, b'', msgpack.packb(response)]) - continue - except: - # Not a control message, continue with normal protobuf processing - pass - - e2e_start = time.time() - lookup_timer = DeviceTimer("text lookup") - - # Parse request - req_proto = embedding_pb2.NodeEmbeddingRequest() - req_proto.ParseFromString(message) - node_ids = req_proto.node_ids - print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}") - - # Add debug information - if len(node_ids) > 0: - print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}") - - # Look up texts - texts = [] - missing_ids = [] - with lookup_timer.timing(): - for nid in node_ids: - txtinfo = passages[nid] - txt = txtinfo["text"] - if txt: - texts.append(txt) - else: - # If text is empty, we still need a placeholder for batch processing, - # but record its ID as missing - texts.append("") - missing_ids.append(nid) - lookup_timer.print_elapsed() - - if missing_ids: - print(f"WARNING: Missing passages for IDs: {missing_ids}") - - # Process batch - total_size = len(texts) - print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}") - - all_embeddings = [] - - if total_size > max_batch_size: - print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}") - for i in range(0, total_size, max_batch_size): - end_idx = min(i + max_batch_size, total_size) - print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}") - - chunk_texts = texts[i:end_idx] - chunk_ids = node_ids[i:end_idx] - - if embedding_mode == "mlx": - embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16) - elif embedding_mode == "openai": - embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name) - else: # sentence-transformers - embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids) - all_embeddings.append(embeddings_chunk) - - if embedding_mode == "sentence-transformers": - if cuda_available: - torch.cuda.empty_cache() - elif device.type == "mps": - torch.mps.empty_cache() - - hidden = np.vstack(all_embeddings) - print(f"INFO: Combined embeddings shape: {hidden.shape}") - else: - if embedding_mode == "mlx": - hidden = compute_embeddings_mlx(texts, model_name, batch_size=16) - elif embedding_mode == "openai": - hidden = compute_embeddings_openai(texts, model_name) - else: # sentence-transformers - hidden = process_batch_pytorch(texts, node_ids, missing_ids) - 
- # Serialize response - ser_start = time.time() - - resp_proto = embedding_pb2.NodeEmbeddingResponse() - hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32) - resp_proto.embeddings_data = hidden_contiguous.tobytes() - resp_proto.dimensions.append(hidden_contiguous.shape[0]) - resp_proto.dimensions.append(hidden_contiguous.shape[1]) - resp_proto.missing_ids.extend(missing_ids) - - response_data = resp_proto.SerializeToString() - - # REP socket sends a single response - socket.send_multipart([identity, b'', response_data]) - - ser_end = time.time() - - print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds") - - if embedding_mode == "sentence-transformers": - if device.type == "cuda": - torch.cuda.synchronize() - elif device.type == "mps": - torch.mps.synchronize() - e2e_end = time.time() - print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds") - - except zmq.Again: - print("INFO: ZMQ socket timeout, continuing to listen") - continue - except Exception as e: - print(f"ERROR: Error in ZMQ server: {e}") - try: - # Send empty response to maintain REQ-REP state - empty_resp = embedding_pb2.NodeEmbeddingResponse() - socket.send(empty_resp.SerializeToString()) - except: - # If sending fails, recreate socket - socket.close() - socket = context.socket(zmq.REP) - socket.bind(f"tcp://127.0.0.1:{zmq_port}") - socket.setsockopt(zmq.RCVTIMEO, 5000) - socket.setsockopt(zmq.SNDTIMEO, 300000) - print("INFO: ZMQ socket recreated after error") - - except Exception as e: - print(f"ERROR: Failed to start embedding server: {e}") - raise - - -def create_embedding_server( - domain="demo", - load_passages=True, - load_embeddings=False, - use_fp16=True, - use_int8=False, - use_cuda_graphs=False, - zmq_port=5555, - max_batch_size=128, - lazy_load_passages=False, - model_name="sentence-transformers/all-mpnet-base-v2", - passages_file: Optional[str] = None, - embedding_mode: str = "sentence-transformers", - enable_warmup: bool = False, -): - """ - 原有的 create_embedding_server 函数保持不变 - 这个是阻塞版本,用于直接运行 - """ - create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Embedding service") - parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") - parser.add_argument("--domain", type=str, default="demo", help="Domain name") - parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping") - parser.add_argument("--load-passages", action="store_true", default=True) - parser.add_argument("--load-embeddings", action="store_true", default=False) - parser.add_argument("--use-fp16", action="store_true", default=False) - parser.add_argument("--use-int8", action="store_true", default=False) - parser.add_argument("--use-cuda-graphs", action="store_true", default=False) - parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting") - parser.add_argument("--lazy-load-passages", action="store_true", default=True) - parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", - help="Embedding model name") - parser.add_argument("--embedding-mode", type=str, default="sentence-transformers", - choices=["sentence-transformers", "mlx", "openai"], - help="Embedding backend mode") - parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)") 
- parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start") - args = parser.parse_args() - - # Handle backward compatibility with use_mlx - embedding_mode = args.embedding_mode - if args.use_mlx: - embedding_mode = "mlx" - - create_embedding_server( - domain=args.domain, - load_passages=args.load_passages, - load_embeddings=args.load_embeddings, - use_fp16=args.use_fp16, - use_int8=args.use_int8, - use_cuda_graphs=args.use_cuda_graphs, - zmq_port=args.zmq_port, - max_batch_size=args.max_batch_size, - lazy_load_passages=args.lazy_load_passages, - model_name=args.model_name, - passages_file=args.passages_file, - embedding_mode=embedding_mode, - enable_warmup=not args.disable_warmup, - ) diff --git a/packages/leann-backend-diskann/third_party/DiskANN b/packages/leann-backend-diskann/third_party/DiskANN index af2a264..25339b0 160000 --- a/packages/leann-backend-diskann/third_party/DiskANN +++ b/packages/leann-backend-diskann/third_party/DiskANN @@ -1 +1 @@ -Subproject commit af2a26481e65232b57b82d96e68833cdee9f7635 +Subproject commit 25339b03413b5067c25b6092ea3e0f77ef8515c8 diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index f1f8da0..0aa903e 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -142,12 +142,12 @@ class HNSWSearcher(BaseSearcher): self, query: np.ndarray, top_k: int, + zmq_port: Optional[int] = None, complexity: int = 64, beam_width: int = 1, prune_ratio: float = 0.0, recompute_embeddings: bool = True, pruning_strategy: Literal["global", "local", "proportional"] = "global", - expected_zmq_port: Optional[int] = None, batch_size: int = 0, **kwargs, ) -> Dict[str, Any]: @@ -165,7 +165,7 @@ class HNSWSearcher(BaseSearcher): - "global": Use global PQ queue size for selection (default) - "local": Local pruning, sort and select best candidates - "proportional": Base selection on new neighbor count ratio - expected_zmq_port: ZMQ port for embedding server + zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True. 
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)
 
@@ -177,6 +177,11 @@ class HNSWSearcher(BaseSearcher):
         if not recompute_embeddings:
             if self.is_pruned:
                 raise RuntimeError("Recompute is required for pruned index.")
+        if recompute_embeddings:
+            if zmq_port is None:
+                raise ValueError(
+                    "zmq_port must be provided if recompute_embeddings is True"
+                )
 
         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -184,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
             faiss.normalize_L2(query)
 
         params = faiss.SearchParametersHNSW()
-        params.zmq_port = expected_zmq_port
+        if zmq_port is not None:
+            params.zmq_port = (
+                zmq_port  # C++ code won't use this if recompute_embeddings is False
+            )
         params.efSearch = complexity
         params.beam_size = beam_width
 
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index cbee3f7..111a52b 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -450,7 +450,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=zmq_port,
+            zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index ee2a59d..9f6be79 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -60,6 +60,9 @@ def compute_embeddings_sentence_transformers(
     """
     Compute embeddings using SentenceTransformer with model caching
     """
+    # Handle empty input
+    if not texts:
+        raise ValueError("Cannot compute embeddings for empty text list")
     logger.info(
         f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
     )
diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py
index 338c3dc..93a6ce8 100644
--- a/packages/leann-core/src/leann/interface.py
+++ b/packages/leann-core/src/leann/interface.py
@@ -64,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters
 
         Returns:
diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py
index 8592ed4..6bd6ec8 100644
--- a/packages/leann-core/src/leann/searcher_base.py
+++ b/packages/leann-core/src/leann/searcher_base.py
@@ -104,6 +104,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         # Try to use embedding server if available and requested
         if use_server_if_available:
             try:
+                # TODO: Maybe we can directly use this port here?
+                # For this internal method, it's ok to assume that the server is running
+                # on that port?
+
                 # Ensure we have a server with passages_file for compatibility
                 passages_source_file = (
                     self.index_dir / f"{self.index_path.name}.meta.json"
@@ -181,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
 
         Returns:
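Usage sketch (illustrative, not taken from the repository): after this change, `zmq_port` must be passed explicitly to `search()` whenever `recompute_embeddings=True`; the DiskANN index is constructed with a placeholder port and switched to the real one at query time via `set_zmq_port()`. The index path, port, and query dimensionality below are assumptions.

```python
import numpy as np
from leann_backend_diskann.diskann_backend import DiskannSearcher

# Hypothetical index prefix and embedding-server port -- adjust to your setup.
searcher = DiskannSearcher("indexes/demo.index", distance_metric="mips", num_threads=8)

query = np.random.rand(1, 768).astype(np.float32)  # dimensionality must match the built index

# With recompute enabled, zmq_port is now mandatory; the searcher updates the
# C++ index's port at runtime if it differs from the one set at construction.
results = searcher.search(
    query,
    top_k=10,
    complexity=64,
    recompute_embeddings=True,
    zmq_port=5557,
    pruning_strategy="global",
)
print(results["labels"][0], results["distances"][0])

# Without recompute, zmq_port may be omitted and stored PQ codes are used instead.
results_pq = searcher.search(query, top_k=10, recompute_embeddings=False)
```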
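Protocol sketch (hedged: assumes a diskann_embedding_server is already listening on the chosen port, e.g. one launched by BaseSearcher with a valid `.meta.json` passages file). The server's REP socket answers protobuf NodeEmbeddingRequest messages from the DiskANN C++ client and plain msgpack text lists from BaseSearcher; the snippet below exercises only the msgpack path.

```python
import msgpack
import numpy as np
import zmq

ZMQ_PORT = 5557  # assumed port of a running diskann_embedding_server

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)  # the server side uses a REP socket
sock.connect(f"tcp://localhost:{ZMQ_PORT}")

# BaseSearcher-style request: a msgpack-encoded list of raw texts.
sock.send(msgpack.packb(["hello world", "leann diskann backend"]))

# The server replies with msgpack-encoded embeddings (embeddings.tolist()).
reply = msgpack.unpackb(sock.recv())
embeddings = np.asarray(reply, dtype=np.float32)
print(embeddings.shape)  # (2, embedding_dim)

sock.close()
ctx.term()
```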