From 1b6272ce0ef6a846b2d1c17c3a8b5025b806e103 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 21 Jul 2025 20:17:25 -0700 Subject: [PATCH] Building, CLI tool & Embedding Server Fixed (#5) * chore: shorter build time * chore: update faiss * fix: no longger do embedding server reuse * fix: do not reuse emb_server and close it properly * feat: cli tool * feat: cli more args * fix: same embedding logic --- README.md | 4 +- packages/leann-backend-diskann/CMakeLists.txt | 6 +- .../leann_backend_diskann/diskann_backend.py | 9 +- .../leann_backend_diskann/embedding_server.py | 64 +- packages/leann-backend-diskann/pyproject.toml | 6 +- packages/leann-backend-hnsw/CMakeLists.txt | 36 +- .../leann_backend_hnsw/hnsw_backend.py | 16 +- .../hnsw_embedding_server.py | 1195 ++--------------- packages/leann-backend-hnsw/pyproject.toml | 6 +- packages/leann-backend-hnsw/third_party/faiss | 2 +- packages/leann-core/pyproject.toml | 3 + packages/leann-core/src/leann/api.py | 275 +--- packages/leann-core/src/leann/cli.py | 287 ++++ .../leann-core/src/leann/embedding_compute.py | 272 ++++ .../src/leann/embedding_server_manager.py | 564 ++++---- packages/leann-core/src/leann/registry.py | 25 +- .../leann-core/src/leann/searcher_base.py | 36 +- pyproject.toml | 1 + uv.lock | 16 +- 19 files changed, 1107 insertions(+), 1716 deletions(-) create mode 100644 packages/leann-core/src/leann/cli.py create mode 100644 packages/leann-core/src/leann/embedding_compute.py diff --git a/README.md b/README.md index dab40b5..1e0140c 100755 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ git submodule update --init --recursive **macOS:** ```bash -brew install llvm libomp boost protobuf +brew install llvm libomp boost protobuf zeromq export CC=$(brew --prefix llvm)/bin/clang export CXX=$(brew --prefix llvm)/bin/clang++ @@ -61,7 +61,7 @@ uv sync --extra diskann **Linux (Ubuntu/Debian):** ```bash -sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev +sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev # Install with HNSW backend (default, recommended for most users) uv sync diff --git a/packages/leann-backend-diskann/CMakeLists.txt b/packages/leann-backend-diskann/CMakeLists.txt index ee9d932..2638282 100644 --- a/packages/leann-backend-diskann/CMakeLists.txt +++ b/packages/leann-backend-diskann/CMakeLists.txt @@ -1,8 +1,8 @@ -# packages/leann-backend-diskann/CMakeLists.txt (最终简化版) +# packages/leann-backend-diskann/CMakeLists.txt (simplified version) cmake_minimum_required(VERSION 3.20) project(leann_backend_diskann_wrapper) -# 告诉 CMake 直接进入 DiskANN 子模块并执行它自己的 CMakeLists.txt -# DiskANN 会自己处理所有事情,包括编译 Python 绑定 +# Tell CMake to directly enter the DiskANN submodule and execute its own CMakeLists.txt +# DiskANN will handle everything itself, including compiling Python bindings add_subdirectory(src/third_party/DiskANN) diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py index adf9182..bbd042d 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py @@ -70,10 +70,6 @@ class DiskannBuilder(LeannBackendBuilderInterface): data_filename = f"{index_prefix}_data.bin" _write_vectors_to_bin(data, index_dir / data_filename) - label_map = {i: str_id for i, str_id in enumerate(ids)} - label_map_file = index_dir / 
"leann.labels.map" - with open(label_map_file, "wb") as f: - pickle.dump(label_map, f) build_kwargs = {**self.build_params, **kwargs} metric_enum = _get_diskann_metrics().get( @@ -211,10 +207,7 @@ class DiskannSearcher(BaseSearcher): ) string_labels = [ - [ - self.label_map.get(int_label, f"unknown_{int_label}") - for int_label in batch_labels - ] + [str(int_label) for int_label in batch_labels] for batch_labels in labels ] diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py index 089ec1f..04f7f56 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py @@ -76,24 +76,11 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: finally: sys.path.pop(0) - # Load label map - passages_dir = Path(meta_file).parent - label_map_file = passages_dir / "leann.labels.map" - - if label_map_file.exists(): - import pickle - with open(label_map_file, 'rb') as f: - label_map = pickle.load(f) - print(f"Loaded label map with {len(label_map)} entries") - else: - raise FileNotFoundError(f"Label map file not found: {label_map_file}") - - print(f"Initialized lazy passage loading for {len(label_map)} passages") + print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages") class LazyPassageLoader(SimplePassageLoader): - def __init__(self, passage_manager, label_map): + def __init__(self, passage_manager): self.passage_manager = passage_manager - self.label_map = label_map # Initialize parent with empty data super().__init__({}) @@ -101,25 +88,22 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: """Get passage by ID with lazy loading""" try: int_id = int(passage_id) - if int_id in self.label_map: - string_id = self.label_map[int_id] - passage_data = self.passage_manager.get_passage(string_id) - if passage_data and passage_data.get("text"): - return {"text": passage_data["text"]} - else: - raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}") + string_id = str(int_id) + passage_data = self.passage_manager.get_passage(string_id) + if passage_data and passage_data.get("text"): + return {"text": passage_data["text"]} else: - raise RuntimeError(f"FATAL: ID {int_id} not found in label_map") + raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}") except Exception as e: raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}") def __len__(self) -> int: - return len(self.label_map) + return len(self.passage_manager.global_offset_map) def keys(self): - return self.label_map.keys() + return self.passage_manager.global_offset_map.keys() - loader = LazyPassageLoader(passage_manager, label_map) + loader = LazyPassageLoader(passage_manager) loader._meta_path = meta_file return loader @@ -135,35 +119,15 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader: if not passages_file.endswith('.jsonl'): raise ValueError(f"Expected .jsonl file format, got: {passages_file}") - # Load label map (int -> string_id) - passages_dir = Path(passages_file).parent - label_map_file = passages_dir / "leann.labels.map" - - label_map = {} - if label_map_file.exists(): - with open(label_map_file, 'rb') as f: - label_map = pickle.load(f) - print(f"Loaded label map with {len(label_map)} entries") - else: - raise FileNotFoundError(f"Label map file not found: {label_map_file}") - - # Load 
passages by string ID - string_id_passages = {} + # Load passages directly by their sequential IDs + passages_data = {} with open(passages_file, 'r', encoding='utf-8') as f: for line in f: if line.strip(): passage = json.loads(line) - string_id_passages[passage['id']] = passage['text'] + passages_data[passage['id']] = passage['text'] - # Create int ID -> text mapping using label map - passages_data = {} - for int_id, string_id in label_map.items(): - if string_id in string_id_passages: - passages_data[str(int_id)] = string_id_passages[string_id] - else: - print(f"WARNING: String ID {string_id} from label map not found in passages") - - print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map") + print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}") return SimplePassageLoader(passages_data) def create_embedding_server_thread( diff --git a/packages/leann-backend-diskann/pyproject.toml b/packages/leann-backend-diskann/pyproject.toml index 8f02feb..24637cc 100644 --- a/packages/leann-backend-diskann/pyproject.toml +++ b/packages/leann-backend-diskann/pyproject.toml @@ -8,11 +8,11 @@ version = "0.1.0" dependencies = ["leann-core==0.1.0", "numpy"] [tool.scikit-build] -# 关键:简化的 CMake 路径 +# Key: simplified CMake path cmake.source-dir = "third_party/DiskANN" -# 关键:Python 包在根目录,路径完全匹配 +# Key: Python package in root directory, paths match exactly wheel.packages = ["leann_backend_diskann"] -# 使用默认的 redirect 模式 +# Use default redirect mode editable.mode = "redirect" cmake.build-type = "Release" build.verbose = true diff --git a/packages/leann-backend-hnsw/CMakeLists.txt b/packages/leann-backend-hnsw/CMakeLists.txt index bcadd12..2b86b0a 100644 --- a/packages/leann-backend-hnsw/CMakeLists.txt +++ b/packages/leann-backend-hnsw/CMakeLists.txt @@ -1,6 +1,7 @@ -# 最终简化版 cmake_minimum_required(VERSION 3.24) project(leann_backend_hnsw_wrapper) +set(CMAKE_C_COMPILER_WORKS 1) +set(CMAKE_CXX_COMPILER_WORKS 1) # Set OpenMP path for macOS if(APPLE) @@ -11,15 +12,9 @@ if(APPLE) set(OpenMP_omp_LIBRARY "/opt/homebrew/opt/libomp/lib/libomp.dylib") endif() -# Build ZeroMQ from source -set(ZMQ_BUILD_TESTS OFF CACHE BOOL "" FORCE) -set(ENABLE_DRAFTS OFF CACHE BOOL "" FORCE) -set(ENABLE_PRECOMPILED OFF CACHE BOOL "" FORCE) -set(WITH_PERF_TOOL OFF CACHE BOOL "" FORCE) -set(WITH_DOCS OFF CACHE BOOL "" FORCE) -set(BUILD_SHARED OFF CACHE BOOL "" FORCE) -set(BUILD_STATIC ON CACHE BOOL "" FORCE) -add_subdirectory(third_party/libzmq) +# Use system ZeroMQ instead of building from source +find_package(PkgConfig REQUIRED) +pkg_check_modules(ZMQ REQUIRED libzmq) # Add cppzmq headers include_directories(third_party/cppzmq) @@ -29,6 +24,7 @@ set(MSGPACK_USE_BOOST OFF CACHE BOOL "" FORCE) add_compile_definitions(MSGPACK_NO_BOOST) include_directories(third_party/msgpack-c/include) +# Faiss configuration - streamlined build set(FAISS_ENABLE_PYTHON ON CACHE BOOL "" FORCE) set(FAISS_ENABLE_GPU OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_EXTRAS OFF CACHE BOOL "" FORCE) @@ -36,4 +32,24 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE) set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE) set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE) +# Disable additional SIMD versions to speed up compilation +set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE) +set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE) + +# Additional optimization options from INSTALL.md +set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE) +set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) # Static library is faster to build + +# Avoid 
building demos and benchmarks +set(BUILD_DEMOS OFF CACHE BOOL "" FORCE) +set(BUILD_BENCHS OFF CACHE BOOL "" FORCE) + +# NEW: Tell Faiss to only build the generic version +set(FAISS_BUILD_GENERIC ON CACHE BOOL "" FORCE) +set(FAISS_BUILD_AVX2 OFF CACHE BOOL "" FORCE) +set(FAISS_BUILD_AVX512 OFF CACHE BOOL "" FORCE) + +# IMPORTANT: Disable building AVX versions to speed up compilation +set(FAISS_BUILD_AVX_VERSIONS OFF CACHE BOOL "" FORCE) + add_subdirectory(third_party/faiss) \ No newline at end of file diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 3ceda37..b7061e1 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -59,10 +59,6 @@ class HNSWBuilder(LeannBackendBuilderInterface): if data.dtype != np.float32: data = data.astype(np.float32) - label_map = {i: str_id for i, str_id in enumerate(ids)} - label_map_file = index_dir / "leann.labels.map" - with open(label_map_file, "wb") as f: - pickle.dump(label_map, f) metric_enum = get_metric_map().get(self.distance_metric.lower()) if metric_enum is None: @@ -142,13 +138,6 @@ class HNSWSearcher(BaseSearcher): self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config) - # Load label mapping - label_map_file = self.index_dir / "leann.labels.map" - if not label_map_file.exists(): - raise FileNotFoundError(f"Label map file not found at {label_map_file}") - - with open(label_map_file, "rb") as f: - self.label_map = pickle.load(f) def search( self, @@ -239,10 +228,7 @@ class HNSWSearcher(BaseSearcher): ) string_labels = [ - [ - self.label_map.get(int_label, f"unknown_{int_label}") - for int_label in batch_labels - ] + [str(int_label) for int_label in batch_labels] for batch_labels in labels ] diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index 579f2bb..48f8e1b 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -1,16 +1,11 @@ -#!/usr/bin/env python3 """ -HNSW-specific embedding server with removed config.py dependencies -Based on DiskANN embedding server architecture +HNSW-specific embedding server """ -import pickle import argparse import threading import time -from transformers import AutoTokenizer, AutoModel import os -from contextlib import contextmanager import zmq import numpy as np import msgpack @@ -24,279 +19,50 @@ RED = "\033[91m" RESET = "\033[0m" # Set up logging based on environment variable -LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper() +LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "INFO").upper() logging.basicConfig( level=getattr(logging, LOG_LEVEL, logging.INFO), - format='%(asctime)s - %(levelname)s - %(message)s' + format="%(asctime)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) -def is_similarity_metric(): - """ - Check if the metric type is similarity-based (like inner product). 
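Editorial note: with `leann.labels.map` removed in both the DiskANN and HNSW builders above, the backends rely on a positional ID convention instead. Later in this patch, `LeannBuilder.add_text` defaults a chunk's ID to `str(len(self.chunks))`, so an integer label coming back from the index resolves to its passage by a plain `str()` conversion. A minimal sketch of that invariant (the texts and the dict stand-in for `PassageManager` are made up):

```python
# Minimal sketch of the positional ID convention that replaces leann.labels.map.
# The dict below is a stand-in for PassageManager; the texts are placeholders.
chunks = ["first passage", "second passage", "third passage"]

# Builder side: the default ID is the chunk's position at insertion time.
passages = {str(i): {"text": text} for i, text in enumerate(chunks)}

# Searcher / embedding-server side: an integer label from the index is
# resolved by stringifying it; no separate mapping file is needed.
def resolve(int_label: int) -> str:
    return passages[str(int_label)]["text"]

assert resolve(1) == "second passage"
```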
- 0 = L2 (distance metric), 1 = Inner Product (similarity metric) - """ - return True # 1 is METRIC_INNER_PRODUCT in FAISS - - -# Function for E5-style average pooling -import torch -from torch import Tensor -import torch.nn.functional as F - -# Timing utilities -@contextmanager -def timer(name: str, sync_cuda: bool = True): - """Context manager for timing operations with optional CUDA sync""" - start_time = time.time() - if sync_cuda and torch.cuda.is_available(): - torch.cuda.synchronize() - try: - yield - finally: - if sync_cuda and torch.cuda.is_available(): - torch.cuda.synchronize() - elif sync_cuda and torch.backends.mps.is_available(): - torch.mps.synchronize() - elapsed = time.time() - start_time - logger.info(f"⏱️ {name}: {elapsed:.4f}s") - - -def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: - last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) - return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] - - -class SimplePassageLoader: - """ - Simple passage loader that replaces config.py dependencies - """ - - def __init__(self, passages_data: Optional[Dict[str, Any]] = None): - self.passages_data = passages_data or {} - self._meta_path = "" - - def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: - """Get passage by ID""" - str_id = str(passage_id) - if str_id in self.passages_data: - return {"text": self.passages_data[str_id]} - else: - # Return empty text for missing passages - return {"text": ""} - - def __len__(self) -> int: - return len(self.passages_data) - - def keys(self): - return self.passages_data.keys() - - -def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader: - """ - Load passages using metadata file with PassageManager for lazy loading - """ - # Load metadata to get passage sources - with open(meta_file, "r") as f: - meta = json.load(f) - - # Import PassageManager dynamically to avoid circular imports - # Find the leann package directory relative to this file - current_dir = Path(__file__).parent - leann_core_path = current_dir.parent.parent / "leann-core" / "src" - sys.path.insert(0, str(leann_core_path)) - - try: - from leann.api import PassageManager - - passage_manager = PassageManager(meta["passage_sources"]) - finally: - sys.path.pop(0) - - # Load label map - passages_dir = Path(meta_file).parent - label_map_file = passages_dir / "leann.labels.map" - - if label_map_file.exists(): - import pickle - - with open(label_map_file, "rb") as f: - label_map = pickle.load(f) - print(f"Loaded label map with {len(label_map)} entries") - else: - raise FileNotFoundError(f"Label map file not found: {label_map_file}") - - print(f"Initialized lazy passage loading for {len(label_map)} passages") - - class LazyPassageLoader(SimplePassageLoader): - def __init__(self, passage_manager, label_map): - self.passage_manager = passage_manager - self.label_map = label_map - # Initialize parent with empty data - super().__init__({}) - - def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]: - """Get passage by ID with lazy loading""" - try: - int_id = int(passage_id) - if int_id in self.label_map: - string_id = self.label_map[int_id] - passage_data = self.passage_manager.get_passage(string_id) - if passage_data and passage_data.get("text"): - return {"text": passage_data["text"]} - else: - logger.debug(f"Empty text for ID {int_id} -> {string_id}") - return {"text": ""} - else: - logger.debug(f"ID {int_id} not found in label_map") - return {"text": ""} - except 
Exception as e: - logger.debug(f"Exception getting passage {passage_id}: {e}") - return {"text": ""} - - def __len__(self) -> int: - return len(self.label_map) - - def keys(self): - return self.label_map.keys() - - return LazyPassageLoader(passage_manager, label_map) - - def create_hnsw_embedding_server( passages_file: Optional[str] = None, passages_data: Optional[Dict[str, str]] = None, - embeddings_file: Optional[str] = None, - use_fp16: bool = True, - use_int8: bool = False, - use_cuda_graphs: bool = False, zmq_port: int = 5555, - max_batch_size: int = 128, model_name: str = "sentence-transformers/all-mpnet-base-v2", - custom_max_length_param: Optional[int] = None, distance_metric: str = "mips", embedding_mode: str = "sentence-transformers", - enable_warmup: bool = False, ): """ Create and start a ZMQ-based embedding server for HNSW backend. - - Args: - passages_file: Path to JSON file containing passage ID -> text mapping - passages_data: Direct passage data dict (alternative to passages_file) - embeddings_file: Path to pre-computed embeddings file (optional) - use_fp16: Whether to use FP16 precision - use_int8: Whether to use INT8 quantization - use_cuda_graphs: Whether to use CUDA graphs - zmq_port: ZMQ port to bind to - max_batch_size: Maximum batch size for processing - model_name: Transformer model name - custom_max_length_param: Custom max sequence length - distance_metric: The distance metric to use - enable_warmup: Whether to perform warmup requests on server start + Simplified version using unified embedding computation module. """ - # Handle different embedding modes directly in HNSW server - # Auto-detect mode based on model name if not explicitly set - if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"): + if embedding_mode == "sentence-transformers" and model_name.startswith( + "text-embedding-" + ): embedding_mode = "openai" - - if embedding_mode == "openai": - print(f"Using OpenAI API mode for {model_name}") - tokenizer = None # No local tokenizer needed for OpenAI API - elif embedding_mode == "mlx": - print(f"Using MLX mode for {model_name}") - tokenizer = None # MLX handles tokenization separately - else: # sentence-transformers - print(f"Loading tokenizer for {model_name}...") - # Optimized tokenizer loading: try local first, then fallback - try: - tokenizer = AutoTokenizer.from_pretrained( - model_name, - use_fast=True, # Use fast tokenizer (better runtime perf) - local_files_only=True # Avoid network delays - ) - print(f"Tokenizer loaded successfully! (local + fast)") - except Exception as e: - print(f"Local tokenizer failed ({e}), trying network download...") - tokenizer = AutoTokenizer.from_pretrained( - model_name, - use_fast=True # Use fast tokenizer - ) - print(f"Tokenizer loaded successfully! (network)") - # Device setup - mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() - cuda_available = torch.cuda.is_available() - - print(f"MPS available: {mps_available}") - print(f"CUDA available: {cuda_available}") - - if cuda_available: - device = torch.device("cuda") - print("Using CUDA device") - elif mps_available: - device = torch.device("mps") - print("Using MPS device (Apple Silicon)") - else: - device = torch.device("cpu") - print("Using CPU device (no GPU acceleration available)") - - # Load model to the appropriate device print(f"Starting HNSW server on port {zmq_port} with model {model_name}") - print(f"Loading model {model_name}... 
(this may take a while if downloading)") + print(f"Using embedding mode: {embedding_mode}") - if embedding_mode == "mlx": - # For MLX models, we need to use the MLX embedding computation - print("MLX model detected - using MLX backend for embeddings") - model = None # We'll handle MLX separately - elif embedding_mode == "openai": - # For OpenAI API, no local model needed - print("OpenAI API mode - no local model loading required") - model = None - else: - # Use optimized transformers loading for sentence-transformers models - print(f"Loading model with optimizations...") - try: - # Ultra-fast loading: preload config + fast_init - from transformers import AutoConfig - config = AutoConfig.from_pretrained(model_name, local_files_only=True) - model = AutoModel.from_pretrained( - model_name, - config=config, - torch_dtype=torch.float16, # Half precision for speed - low_cpu_mem_usage=True, # Reduce memory peaks - local_files_only=True, # Avoid network delays - _fast_init=True # Skip weight init checks - ).to(device).eval() - print(f"Model {model_name} loaded successfully! (ultra-fast)") - except Exception as e: - print(f"Ultra-fast loading failed ({e}), trying optimized...") - try: - # Fallback: regular optimized loading - model = AutoModel.from_pretrained( - model_name, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - local_files_only=True - ).to(device).eval() - print(f"Model {model_name} loaded successfully! (optimized)") - except Exception as e2: - print(f"Optimized loading failed ({e2}), trying network...") - try: - # Fallback: optimized network loading - model = AutoModel.from_pretrained( - model_name, - torch_dtype=torch.float16, - low_cpu_mem_usage=True - ).to(device).eval() - print(f"Model {model_name} loaded successfully! (network + optimized)") - except Exception as e3: - print(f"All optimized methods failed ({e3}), using standard...") - # Final fallback: standard loading - model = AutoModel.from_pretrained(model_name).to(device).eval() - print(f"Model {model_name} loaded successfully! 
(standard)") + # Add leann-core to path for unified embedding computation + current_dir = Path(__file__).parent + leann_core_path = current_dir.parent.parent / "leann-core" / "src" + sys.path.insert(0, str(leann_core_path)) + + try: + from leann.embedding_compute import compute_embeddings + from leann.api import PassageManager + + print("Successfully imported unified embedding computation module") + except ImportError as e: + print(f"ERROR: Failed to import embedding computation module: {e}") + return + finally: + sys.path.pop(0) # Check port availability import socket @@ -309,314 +75,19 @@ def create_hnsw_embedding_server( print(f"{RED}Port {zmq_port} is already in use{RESET}") return - # Apply model optimizations (similar to DiskANN version) - if use_fp16 and (cuda_available or mps_available): - model = model.half() - model = torch.compile(model) - print(f"Using FP16 precision with model: {model_name}") - elif use_int8: - print( - "- Using TorchAO for Int8 dynamic activation and Int8 weight quantization" - ) - from torchao.quantization import ( - quantize_, - Int8DynamicActivationInt8WeightConfig, - ) - - quantize_(model, Int8DynamicActivationInt8WeightConfig()) - model = torch.compile(model) - model.eval() - print("- Model successfully quantized and compiled") - - # Load passages - if passages_data: - passages = SimplePassageLoader(passages_data) - print(f"Using provided passages data: {len(passages)} passages") - elif passages_file: - # Check if it's a metadata file or a single passages file - if passages_file.endswith(".meta.json"): - passages = load_passages_from_metadata(passages_file) - # Store the meta path for future reference - passages._meta_path = passages_file - else: - # Try to find metadata file in same directory - passages_dir = Path(passages_file).parent - meta_files = list(passages_dir.glob("*.meta.json")) - if meta_files: - print(f"Found metadata file: {meta_files[0]}, using lazy loading") - passages = load_passages_from_metadata(str(meta_files[0])) - else: - # Fallback to original single file loading (will cause warnings) - print( - "WARNING: No metadata file found, using single file loading (may cause missing passage warnings)" - ) - passages = ( - SimplePassageLoader() - ) # Use empty loader to avoid massive warnings - else: - passages = SimplePassageLoader() - print("No passages provided, using empty loader") - - # Load embeddings if provided - _embeddings = None - if embeddings_file and os.path.exists(embeddings_file): - try: - with open(embeddings_file, "rb") as f: - _embeddings = pickle.load(f) - print(f"Loaded embeddings from {embeddings_file}") - except Exception as e: - print(f"Error loading embeddings: {e}") - - class DeviceTimer: - """Device event-based timer for accurate timing.""" - - def __init__(self, name="", device=device): - self.name = name - self.device = device - self.start_time = 0 - self.end_time = 0 - - if cuda_available: - self.start_event = torch.cuda.Event(enable_timing=True) - self.end_event = torch.cuda.Event(enable_timing=True) - else: - self.start_event = None - self.end_event = None - - @contextmanager - def timing(self): - self.start() - yield - self.end() - - def start(self): - if cuda_available: - torch.cuda.synchronize() - self.start_event.record() - else: - if self.device.type == "mps": - torch.mps.synchronize() - self.start_time = time.time() - - def end(self): - if cuda_available: - self.end_event.record() - torch.cuda.synchronize() - else: - if self.device.type == "mps": - torch.mps.synchronize() - self.end_time = time.time() - - 
def elapsed_time(self): - if cuda_available: - return self.start_event.elapsed_time(self.end_event) / 1000.0 - else: - return self.end_time - self.start_time - - def print_elapsed(self): - return # Disabled for now - - def _process_batch_mlx(texts_batch, ids_batch, missing_ids): - """Process a batch of texts using MLX backend""" - try: - # Import MLX embedding computation from main API - from leann.api import compute_embeddings - - # Compute embeddings using MLX - embeddings = compute_embeddings(texts_batch, model_name, mode="mlx", use_server=False) - - print( - f"[leann_backend_hnsw.hnsw_embedding_server LOG]: MLX embeddings computed for {len(texts_batch)} texts" - ) - print( - f"[leann_backend_hnsw.hnsw_embedding_server LOG]: Embedding shape: {embeddings.shape}" - ) - - return embeddings - - except Exception as e: - print( - f"[leann_backend_hnsw.hnsw_embedding_server LOG]: ERROR in MLX processing: {e}" - ) - raise - - def process_batch(texts_batch, ids_batch, missing_ids): - """Process a batch of texts and return embeddings""" - - # Handle different embedding modes - if embedding_mode == "mlx": - return _process_batch_mlx(texts_batch, ids_batch, missing_ids) - elif embedding_mode == "openai": - with timer("OpenAI API call", sync_cuda=False): - from leann.api import compute_embeddings_openai - return compute_embeddings_openai(texts_batch, model_name) - - _is_e5_model = "e5" in model_name.lower() - _is_bge_model = "bge" in model_name.lower() - batch_size = len(texts_batch) - - # Allow empty texts to pass through (remove validation) - - # E5 model preprocessing - if _is_e5_model: - processed_texts_batch = [f"passage: {text}" for text in texts_batch] - else: - processed_texts_batch = texts_batch - - # Set max length - if _is_e5_model: - current_max_length = ( - custom_max_length_param if custom_max_length_param is not None else 512 - ) - else: - current_max_length = ( - custom_max_length_param if custom_max_length_param is not None else 256 - ) - - tokenize_timer = DeviceTimer("tokenization (batch)", device) - to_device_timer = DeviceTimer("transfer to device (batch)", device) - embed_timer = DeviceTimer("embedding (batch)", device) - pool_timer = DeviceTimer("pooling (batch)", device) - norm_timer = DeviceTimer("normalization (batch)", device) - - with tokenize_timer.timing(): - encoded_batch = tokenizer( - processed_texts_batch, - padding="max_length", - truncation=True, - max_length=current_max_length, - return_tensors="pt", - return_token_type_ids=False, - ) - - seq_length = encoded_batch["input_ids"].size(1) - - with to_device_timer.timing(): - enc = {k: v.to(device) for k, v in encoded_batch.items()} - - with torch.no_grad(): - with timer("Model forward pass"): - with embed_timer.timing(): - out = model(enc["input_ids"], enc["attention_mask"]) - - with timer("Pooling"): - with pool_timer.timing(): - if _is_bge_model: - pooled_embeddings = out.last_hidden_state[:, 0] - elif not hasattr(out, "last_hidden_state"): - if isinstance(out, torch.Tensor) and len(out.shape) == 2: - pooled_embeddings = out - else: - print( - f"{RED}ERROR: Cannot determine how to pool. 
Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}" - ) - hidden_dim = getattr( - model.config, "hidden_size", 384 if _is_e5_model else 768 - ) - pooled_embeddings = torch.zeros( - (batch_size, hidden_dim), - device=device, - dtype=enc["input_ids"].dtype - if hasattr(enc["input_ids"], "dtype") - else torch.float32, - ) - elif _is_e5_model: - pooled_embeddings = e5_average_pool( - out.last_hidden_state, enc["attention_mask"] - ) - else: - hidden_states = out.last_hidden_state - mask_expanded = ( - enc["attention_mask"] - .unsqueeze(-1) - .expand(hidden_states.size()) - .float() - ) - sum_embeddings = torch.sum(hidden_states * mask_expanded, 1) - sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9) - pooled_embeddings = sum_embeddings / sum_mask - - final_embeddings = pooled_embeddings - if _is_e5_model or _is_bge_model: - with norm_timer.timing(): - final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1) - - if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any(): - print( - f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! " - f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}" - ) - dim_size = final_embeddings.shape[-1] - error_output = torch.zeros( - (batch_size, dim_size), device="cpu", dtype=torch.float32 - ).numpy() - print( - f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}" - ) - return error_output - - return final_embeddings.cpu().numpy() - - def client_warmup(zmq_port): - """Perform client-side warmup""" - time.sleep(2) - print(f"Performing client-side warmup with model {model_name}...") - - # Get actual passage IDs from the loaded passages - sample_ids = [] - if hasattr(passages, 'keys') and len(passages) > 0: - available_ids = list(passages.keys()) - # Take up to 5 actual IDs, but at least 1 - sample_ids = available_ids[:min(5, len(available_ids))] - print(f"Using actual passage IDs for warmup: {sample_ids}") - else: - print("No passages available for warmup, skipping warmup...") - return - - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect(f"tcp://localhost:{zmq_port}") - socket.setsockopt(zmq.RCVTIMEO, 30000) - socket.setsockopt(zmq.SNDTIMEO, 30000) - - try: - ids_to_send = [int(x) for x in sample_ids] - except ValueError: - print("Warning: Could not convert sample IDs to integers, skipping warmup") - return - - if not ids_to_send: - print("Skipping warmup send.") - return - - request_payload = [ids_to_send] - request_bytes = msgpack.packb(request_payload) - - for i in range(3): - print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...") - socket.send(request_bytes) - response_bytes = socket.recv() - - response_payload = msgpack.unpackb(response_bytes) - dimensions = response_payload[0] - embeddings_count = ( - dimensions[0] if dimensions and len(dimensions) > 0 else 0 - ) - print( - f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings" - ) - time.sleep(0.1) - - print("Client-side MessagePack ZMQ warmup complete") - socket.close() - context.term() - except Exception as e: - print(f"Error during MessagePack ZMQ warmup: {e}") + # Only support metadata file, fail fast for everything else + if not passages_file or not passages_file.endswith(".meta.json"): + raise ValueError("Only metadata files (.meta.json) are supported") + + # Load metadata to get passage sources + with open(passages_file, "r") as f: + meta = json.load(f) + + passages = 
PassageManager(meta["passage_sources"]) + print(f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata") def zmq_server_thread(): """ZMQ server thread""" - nonlocal passages, model, tokenizer, model_name, embedding_mode context = zmq.Context() socket = context.socket(zmq.REP) socket.bind(f"tcp://*:{zmq_port}") @@ -631,457 +102,139 @@ def create_hnsw_embedding_server( print(f"Received ZMQ request of size {len(message_bytes)} bytes") e2e_start = time.time() - lookup_timer = DeviceTimer("text lookup", device) + request_payload = msgpack.unpackb(message_bytes) - try: - request_payload = msgpack.unpackb(message_bytes) - if isinstance(request_payload, list): - logger.debug(f"request_payload length: {len(request_payload)}") - for i, item in enumerate(request_payload): - print( - f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}" - ) + # Handle direct text embedding request (for OpenAI and sentence-transformers) + if isinstance(request_payload, list) and len(request_payload) > 0: + # Check if this is a direct text request (list of strings) + if all(isinstance(item, str) for item in request_payload): + logger.info( + f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode" + ) - # Handle control messages for meta path and model management FIRST - if isinstance(request_payload, list) and len(request_payload) >= 1: - if request_payload[0] == "__QUERY_META_PATH__": - # Return the current meta path being used by the server - current_meta_path = ( - getattr(passages, "_meta_path", "") - if hasattr(passages, "_meta_path") - else "" - ) - response = [current_meta_path] - socket.send(msgpack.packb(response)) - continue + # Use unified embedding computation + embeddings = compute_embeddings( + request_payload, model_name, mode=embedding_mode + ) - elif ( - request_payload[0] == "__UPDATE_META_PATH__" - and len(request_payload) >= 2 - ): - # Update the server's meta path and reload passages - new_meta_path = request_payload[1] - try: - print( - f"INFO: Updating server meta path to: {new_meta_path}" - ) - # Reload passages from the new meta file - passages = load_passages_from_metadata(new_meta_path) - # Store the meta path for future queries - passages._meta_path = new_meta_path - response = ["SUCCESS"] - print( - f"INFO: Successfully updated meta path and reloaded {len(passages)} passages" - ) - except Exception as e: - print(f"ERROR: Failed to update meta path: {e}") - response = ["FAILED", str(e)] - socket.send(msgpack.packb(response)) - continue - - elif request_payload[0] == "__QUERY_MODEL__": - # Return the current model being used by the server - response = [model_name] - socket.send(msgpack.packb(response)) - continue - - elif ( - request_payload[0] == "__UPDATE_MODEL__" - and len(request_payload) >= 2 - ): - # Update the server's embedding model - new_model_name = request_payload[1] - try: - print( - f"INFO: Updating server model from {model_name} to: {new_model_name}" - ) - - # Clean up old model to free memory - logger.info("Releasing old model from memory...") - old_model = model - old_tokenizer = tokenizer - - # Load new tokenizer first (optimized) - print(f"Loading new tokenizer for {new_model_name}...") - try: - tokenizer = AutoTokenizer.from_pretrained( - new_model_name, - use_fast=True, - local_files_only=True - ) - print(f"New tokenizer loaded! 
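Editorial note: the simplified server above resolves passages through `PassageManager` driven by the index's `.meta.json`, and the same lookup can be exercised directly. This is only a sketch; the metadata path and the passage ID `"0"` are hypothetical placeholders, and it assumes the metadata file contains a `passage_sources` list as the server expects.

```python
# Sketch of the PassageManager lookup path used by the simplified server.
# "my_index.meta.json" and passage ID "0" are hypothetical placeholders.
import json

from leann.api import PassageManager

with open("my_index.meta.json") as f:
    meta = json.load(f)

passages = PassageManager(meta["passage_sources"])
print(f"{len(passages.global_offset_map)} passages available")
print(passages.get_passage("0")["text"][:80])
```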
(local + fast)") - except: - tokenizer = AutoTokenizer.from_pretrained( - new_model_name, - use_fast=True - ) - print(f"New tokenizer loaded! (network + fast)") - - # Load new model (optimized) - print(f"Loading new model {new_model_name}...") - try: - # Ultra-fast model switching - from transformers import AutoConfig - config = AutoConfig.from_pretrained(new_model_name, local_files_only=True) - model = AutoModel.from_pretrained( - new_model_name, - config=config, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - local_files_only=True, - _fast_init=True - ) - print(f"New model loaded! (ultra-fast)") - except: - try: - model = AutoModel.from_pretrained( - new_model_name, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - local_files_only=True - ) - print(f"New model loaded! (optimized)") - except: - try: - model = AutoModel.from_pretrained( - new_model_name, - torch_dtype=torch.float16, - low_cpu_mem_usage=True - ) - print(f"New model loaded! (network + optimized)") - except: - model = AutoModel.from_pretrained(new_model_name) - print(f"New model loaded! (standard)") - model.to(device) - model.eval() - - # Now safely delete old model after new one is loaded - del old_model - del old_tokenizer - - # Clear GPU cache if available - if device.type == "cuda": - torch.cuda.empty_cache() - logger.info("Cleared CUDA cache") - elif device.type == "mps": - torch.mps.empty_cache() - logger.info("Cleared MPS cache") - - # Update model name - model_name = new_model_name - - # Re-detect embedding mode based on new model name - if model_name.startswith("text-embedding-"): - embedding_mode = "openai" - logger.info(f"Auto-detected embedding mode: openai for {model_name}") - else: - embedding_mode = "sentence-transformers" - logger.info(f"Auto-detected embedding mode: sentence-transformers for {model_name}") - - # Force garbage collection - import gc - - gc.collect() - logger.info("Memory cleanup completed") - - response = ["SUCCESS"] - print( - f"INFO: Successfully updated model to: {new_model_name}" - ) - except Exception as e: - print(f"ERROR: Failed to update model: {e}") - response = ["FAILED", str(e)] - socket.send(msgpack.packb(response)) - continue - - # Handle direct text embedding request (for OpenAI and sentence-transformers) - if isinstance(request_payload, list) and len(request_payload) > 0: - # Check if this is a direct text request (list of strings) and NOT a control message - if (all(isinstance(item, str) for item in request_payload) and - not request_payload[0].startswith("__")): - logger.info(f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode") - - try: - if embedding_mode == "openai": - from leann.api import compute_embeddings_openai - embeddings = compute_embeddings_openai(request_payload, model_name) - else: - # sentence-transformers mode - compute directly - with timer(f"Direct text embedding ({len(request_payload)} texts)"): - embeddings = process_batch(request_payload, [], []) - - response = embeddings.tolist() - socket.send(msgpack.packb(response)) - e2e_end = time.time() - logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") - continue - except Exception as e: - logger.error(f"ERROR: Failed to compute {embedding_mode} embeddings: {e}") - socket.send(msgpack.packb([])) - continue - - # Handle distance calculation requests - if ( - isinstance(request_payload, list) - and len(request_payload) == 2 - and isinstance(request_payload[0], list) - and isinstance(request_payload[1], list) - ): - node_ids = 
request_payload[0] - query_vector = np.array(request_payload[1], dtype=np.float32) - - logger.debug("Distance calculation request received") - print(f" Node IDs: {node_ids}") - print(f" Query vector dim: {len(query_vector)}") - print(f" Passages loaded: {len(passages)}") - - # Get embeddings for node IDs - texts = [] - missing_ids = [] - with lookup_timer.timing(): - for nid in node_ids: - logger.debug(f"Looking up passage ID {nid}") - try: - txtinfo = passages[nid] - if txtinfo is None: - print( - f"ERROR: Passage with ID {nid} returned None" - ) - print(f"ERROR: txtinfo: {txtinfo}") - raise RuntimeError( - f"FATAL: Passage with ID {nid} returned None" - ) - txt = txtinfo[ - "text" - ] # Allow empty text to pass through - print( - f"DEBUG: Found text for ID {nid}, length: {len(txt)}" - ) - texts.append(txt) - except KeyError: - print( - f"ERROR: Passage ID {nid} not found in passages dict" - ) - print( - f"ERROR: Available passage IDs: {list(passages.keys())}..." - ) - raise RuntimeError( - f"FATAL: Passage with ID {nid} not found" - ) - except Exception as e: - print( - f"ERROR: Exception looking up passage ID {nid}: {e}" - ) - raise - lookup_timer.print_elapsed() - - # Process embeddings in chunks if needed - all_node_embeddings = [] - total_size = len(texts) - - if total_size > max_batch_size: - for i in range(0, total_size, max_batch_size): - end_idx = min(i + max_batch_size, total_size) - chunk_texts = texts[i:end_idx] - chunk_ids = node_ids[i:end_idx] - - embeddings_chunk = process_batch( - chunk_texts, chunk_ids, missing_ids - ) - all_node_embeddings.append(embeddings_chunk) - - if cuda_available: - torch.cuda.empty_cache() - elif device.type == "mps": - torch.mps.empty_cache() - - node_embeddings = np.vstack(all_node_embeddings) - else: - node_embeddings = process_batch( - texts, node_ids, missing_ids - ) - - # Calculate distances - query_tensor = torch.tensor(query_vector, device=device).float() - node_embeddings_tensor = torch.tensor( - node_embeddings, device=device - ).float() - - calc_timer = DeviceTimer("distance calculation", device) - with calc_timer.timing(): - with torch.no_grad(): - if distance_metric == "l2": - node_embeddings_np = ( - node_embeddings_tensor.cpu() - .numpy() - .astype(np.float32) - ) - query_np = ( - query_tensor.cpu().numpy().astype(np.float32) - ) - distances = np.sum( - np.square( - node_embeddings_np - query_np.reshape(1, -1) - ), - axis=1, - ) - else: # mips or cosine - node_embeddings_np = ( - node_embeddings_tensor.cpu().numpy() - ) - query_np = query_tensor.cpu().numpy() - distances = -np.dot(node_embeddings_np, query_np) - calc_timer.print_elapsed() - - try: - response_payload = distances.flatten().tolist() - response_bytes = msgpack.packb( - [response_payload], use_single_float=True - ) - print( - f"Sending distance response with {len(distances)} distances" - ) - except Exception as pack_error: - print( - f"ERROR: Error packing MessagePack distance response: {pack_error}" - ) - print(f"ERROR: distances shape: {distances.shape}") - print(f"ERROR: distances dtype: {distances.dtype}") - print(f"ERROR: distances content: {distances}") - print(f"ERROR: node_ids: {node_ids}") - print(f"ERROR: query_vector shape: {query_vector.shape}") - # Still return empty for now but with full error info - response_bytes = msgpack.packb([[]]) - - socket.send(response_bytes) - - if device.type == "cuda": - torch.cuda.synchronize() - elif device.type == "mps": - torch.mps.synchronize() + response = embeddings.tolist() + socket.send(msgpack.packb(response)) e2e_end 
= time.time() logger.info( - f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" + f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s" ) continue - - # Standard embedding request (passage ID lookup) - if ( - not isinstance(request_payload, list) - or len(request_payload) != 1 - or not isinstance(request_payload[0], list) - ): - print( - f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}" - ) - socket.send(msgpack.packb([[], []])) - continue - + # Handle distance calculation requests + if ( + isinstance(request_payload, list) + and len(request_payload) == 2 + and isinstance(request_payload[0], list) + and isinstance(request_payload[1], list) + ): node_ids = request_payload[0] - print(f"Request for {len(node_ids)} node embeddings") + query_vector = np.array(request_payload[1], dtype=np.float32) - except Exception as unpack_error: - print(f"Error unpacking MessagePack request: {unpack_error}") + logger.debug("Distance calculation request received") + print(f" Node IDs: {node_ids}") + print(f" Query vector dim: {len(query_vector)}") + + # Get embeddings for node IDs + texts = [] + for nid in node_ids: + try: + passage_data = passages.get_passage(str(nid)) + txt = passage_data["text"] + texts.append(txt) + except KeyError: + print(f"ERROR: Passage ID {nid} not found") + raise RuntimeError(f"FATAL: Passage with ID {nid} not found") + except Exception as e: + print(f"ERROR: Exception looking up passage ID {nid}: {e}") + raise + + # Process embeddings + embeddings = compute_embeddings( + texts, model_name, mode=embedding_mode + ) + print( + f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + + # Calculate distances + if distance_metric == "l2": + distances = np.sum( + np.square(embeddings - query_vector.reshape(1, -1)), axis=1 + ) + else: # mips or cosine + distances = -np.dot(embeddings, query_vector) + + response_payload = distances.flatten().tolist() + response_bytes = msgpack.packb( + [response_payload], use_single_float=True + ) + print(f"Sending distance response with {len(distances)} distances") + + socket.send(response_bytes) + e2e_end = time.time() + logger.info( + f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" + ) + continue + + # Standard embedding request (passage ID lookup) + if ( + not isinstance(request_payload, list) + or len(request_payload) != 1 + or not isinstance(request_payload[0], list) + ): + print( + f"Error: Invalid MessagePack request format. 
Expected [[ids...]] or [texts...], got: {type(request_payload)}" + ) socket.send(msgpack.packb([[], []])) continue + node_ids = request_payload[0] + print(f"Request for {len(node_ids)} node embeddings") + # Look up texts by node IDs texts = [] - missing_ids = [] - with lookup_timer.timing(): - for nid in node_ids: - try: - txtinfo = passages[nid] - if txtinfo is None or txtinfo["text"] == "": - raise RuntimeError( - f"FATAL: Passage with ID {nid} not found - failing fast" - ) - else: - txt = txtinfo["text"] - except (KeyError, IndexError): - raise RuntimeError( - f"FATAL: Passage with ID {nid} not found - failing fast" - ) + for nid in node_ids: + try: + passage_data = passages.get_passage(str(nid)) + txt = passage_data["text"] + if not txt: + raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") texts.append(txt) - lookup_timer.print_elapsed() + except KeyError: + raise RuntimeError(f"FATAL: Passage with ID {nid} not found") + except Exception as e: + print(f"ERROR: Exception looking up passage ID {nid}: {e}") + raise - if missing_ids: - print(f"Missing passages for IDs: {missing_ids}") - - # Process in chunks - total_size = len(texts) + # Process embeddings + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) print( - f"Total batch size: {total_size}, max_batch_size: {max_batch_size}" + f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) - all_embeddings = [] - - if total_size > max_batch_size: - print( - f"Splitting batch of size {total_size} into chunks of {max_batch_size}" - ) - for i in range(0, total_size, max_batch_size): - end_idx = min(i + max_batch_size, total_size) - print( - f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: items {i} to {end_idx - 1}" - ) - - chunk_texts = texts[i:end_idx] - chunk_ids = node_ids[i:end_idx] - - embeddings_chunk = process_batch( - chunk_texts, chunk_ids, missing_ids - ) - all_embeddings.append(embeddings_chunk) - - if cuda_available: - torch.cuda.empty_cache() - elif device.type == "mps": - torch.mps.empty_cache() - - hidden = np.vstack(all_embeddings) - print(f"Combined embeddings shape: {hidden.shape}") - else: - hidden = process_batch(texts, node_ids, missing_ids) - # Serialization and response - ser_start = time.time() - - print( - f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}" - ) - if np.isnan(hidden).any() or np.isinf(hidden).any(): + if np.isnan(embeddings).any() or np.isinf(embeddings).any(): print( - f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! " - f"Requested IDs (sample): {node_ids[:5]}...{RESET}" + f"{RED}!!! ERROR: NaN or Inf detected in embeddings! 
Requested IDs: {node_ids[:5]}...{RESET}" ) assert False - try: - hidden_contiguous_f32 = np.ascontiguousarray( - hidden, dtype=np.float32 - ) - response_payload = [ - list(hidden_contiguous_f32.shape), - hidden_contiguous_f32.flatten().tolist(), - ] - response_bytes = msgpack.packb( - response_payload, use_single_float=True - ) - except Exception as pack_error: - print(f"Error packing MessagePack response: {pack_error}") - response_bytes = msgpack.packb([[], []]) + hidden_contiguous_f32 = np.ascontiguousarray( + embeddings, dtype=np.float32 + ) + response_payload = [ + list(hidden_contiguous_f32.shape), + hidden_contiguous_f32.flatten().tolist(), + ] + response_bytes = msgpack.packb(response_payload, use_single_float=True) socket.send(response_bytes) - ser_end = time.time() - - print(f"Serialize time: {ser_end - ser_start:.6f} seconds") - - if device.type == "cuda": - torch.cuda.synchronize() - elif device.type == "mps": - torch.mps.synchronize() e2e_end = time.time() logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") @@ -1093,19 +246,7 @@ def create_hnsw_embedding_server( import traceback traceback.print_exc() - try: - socket.send(msgpack.packb([[], []])) - except: - pass - - # Start warmup and server threads - if enable_warmup and len(passages) > 0: - print(f"Warmup enabled: starting warmup thread") - warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,)) - warmup_thread.daemon = True - warmup_thread.start() - else: - print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})") + socket.send(msgpack.packb([[], []])) zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread.start() @@ -1128,74 +269,30 @@ if __name__ == "__main__": type=str, help="JSON file containing passage ID to text mapping", ) - parser.add_argument( - "--embeddings-file", - type=str, - help="Pickle file containing pre-computed embeddings", - ) - parser.add_argument("--use-fp16", action="store_true", default=False) - parser.add_argument("--use-int8", action="store_true", default=False) - parser.add_argument("--use-cuda-graphs", action="store_true", default=False) - parser.add_argument( - "--max-batch-size", - type=int, - default=128, - help="Maximum batch size before splitting", - ) parser.add_argument( "--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2", help="Embedding model name", ) - parser.add_argument( - "--custom-max-length", - type=int, - default=None, - help="Override model's default max sequence length", - ) parser.add_argument( "--distance-metric", type=str, default="mips", help="Distance metric to use" ) parser.add_argument( - "--embedding-mode", - type=str, - default="sentence-transformers", - choices=["sentence-transformers", "mlx", "openai"], - help="Embedding backend mode" - ) - parser.add_argument( - "--use-mlx", - action="store_true", - default=False, - help="Use MLX for model inference (deprecated: use --embedding-mode mlx)", - ) - parser.add_argument( - "--disable-warmup", - action="store_true", - default=False, - help="Disable warmup requests on server start", + "--embedding-mode", + type=str, + default="sentence-transformers", + choices=["sentence-transformers", "openai"], + help="Embedding backend mode", ) args = parser.parse_args() - - # Handle backward compatibility with use_mlx - embedding_mode = args.embedding_mode - if args.use_mlx: - embedding_mode = "mlx" # Create and start the HNSW embedding server create_hnsw_embedding_server( passages_file=args.passages_file, - 
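Editorial note: taken together, the rewritten server speaks a three-shape MessagePack protocol over a ZMQ REQ/REP socket: a flat list of strings returns embedding rows, a `[node_ids, query_vector]` pair returns distances (negated inner products for the default `mips` metric, squared L2 otherwise), and `[[node_ids]]` returns `[shape, flat_floats]`. The client sketch below is an illustration only; the port 5555 and passage IDs 0-2 are assumptions, and it requires a server already serving an index that contains those passages.

```python
# Client sketch for the simplified ZMQ protocol; port 5555 and the passage
# IDs 0-2 are assumptions, not values taken from a real index.
import msgpack
import numpy as np
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)
sock.connect("tcp://localhost:5555")

# 1) Direct text request: list of strings -> list of embedding rows.
sock.send(msgpack.packb(["hello world", "another passage"]))
embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)

# 2) Distance request: [node_ids, query_vector] -> [[d0, d1, ...]].
query = embeddings[0]
sock.send(msgpack.packb([[0, 1, 2], query.tolist()]))
distances = msgpack.unpackb(sock.recv())[0]

# 3) Embedding lookup by passage ID: [[ids]] -> [shape, flat_floats].
sock.send(msgpack.packb([[0, 1, 2]]))
shape, flat = msgpack.unpackb(sock.recv())
node_embeddings = np.array(flat, dtype=np.float32).reshape(shape)

print(embeddings.shape, len(distances), node_embeddings.shape)

sock.close()
ctx.term()
```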
embeddings_file=args.embeddings_file, - use_fp16=args.use_fp16, - use_int8=args.use_int8, - use_cuda_graphs=args.use_cuda_graphs, zmq_port=args.zmq_port, - max_batch_size=args.max_batch_size, model_name=args.model_name, - custom_max_length_param=args.custom_max_length, distance_metric=args.distance_metric, - embedding_mode=embedding_mode, - enable_warmup=not args.disable_warmup, + embedding_mode=args.embedding_mode, ) diff --git a/packages/leann-backend-hnsw/pyproject.toml b/packages/leann-backend-hnsw/pyproject.toml index 12df4d6..274f2b4 100644 --- a/packages/leann-backend-hnsw/pyproject.toml +++ b/packages/leann-backend-hnsw/pyproject.toml @@ -15,4 +15,8 @@ wheel.packages = ["leann_backend_hnsw"] editable.mode = "redirect" cmake.build-type = "Release" build.verbose = true -build.tool-args = ["-j8"] \ No newline at end of file +build.tool-args = ["-j8"] + +# CMake definitions to optimize compilation +[tool.scikit-build.cmake.define] +CMAKE_BUILD_PARALLEL_LEVEL = "8" \ No newline at end of file diff --git a/packages/leann-backend-hnsw/third_party/faiss b/packages/leann-backend-hnsw/third_party/faiss index 2547df4..ff22e2c 160000 --- a/packages/leann-backend-hnsw/third_party/faiss +++ b/packages/leann-backend-hnsw/third_party/faiss @@ -1 +1 @@ -Subproject commit 2547df4377ae097e2eabc9b019c15135b1fea2b4 +Subproject commit ff22e2c86be1784c760265abe146b1ab0db90ebe diff --git a/packages/leann-core/pyproject.toml b/packages/leann-core/pyproject.toml index 7f64793..08d2b4e 100644 --- a/packages/leann-core/pyproject.toml +++ b/packages/leann-core/pyproject.toml @@ -15,5 +15,8 @@ dependencies = [ "tqdm>=4.60.0" ] +[project.scripts] +leann = "leann.cli:main" + [tool.setuptools.packages.find] where = ["src"] \ No newline at end of file diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 832093d..d5f3a53 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -9,9 +9,6 @@ import numpy as np from pathlib import Path from typing import List, Dict, Any, Optional, Literal from dataclasses import dataclass, field -import uuid -import torch - from .registry import BACKEND_REGISTRY from .interface import LeannBackendFactoryInterface from .chat import get_llm @@ -22,7 +19,7 @@ def compute_embeddings( model_name: str, mode: str = "sentence-transformers", use_server: bool = True, - use_mlx: bool = False # Backward compatibility: if True, override mode to 'mlx', + port: int = 5557, ) -> np.ndarray: """ Computes embeddings using different backends. @@ -39,251 +36,60 @@ def compute_embeddings( Returns: numpy array of embeddings """ - # Override mode for backward compatibility - if use_mlx: - mode = "mlx" - - # Auto-detect mode based on model name if not explicitly set - if mode == "sentence-transformers" and model_name.startswith("text-embedding-"): - mode = "openai" - - if mode == "mlx": - return compute_embeddings_mlx(chunks, model_name, batch_size=16) - elif mode == "openai": - return compute_embeddings_openai(chunks, model_name) - elif mode == "sentence-transformers": - return compute_embeddings_sentence_transformers( - chunks, model_name, use_server=use_server - ) + if use_server: + # Use embedding server (for search/query) + return compute_embeddings_via_server(chunks, model_name, port=port) else: - raise ValueError( - f"Unsupported embedding mode: {mode}. 
Supported modes: sentence-transformers, mlx, openai" + # Use direct computation (for build_index) + from .embedding_compute import ( + compute_embeddings as compute_embeddings_direct, + ) + + return compute_embeddings_direct( + chunks, + model_name, + mode=mode, ) -def compute_embeddings_sentence_transformers( - chunks: List[str], model_name: str, use_server: bool = True +def compute_embeddings_via_server( + chunks: List[str], model_name: str, port: int ) -> np.ndarray: """Computes embeddings using sentence-transformers. Args: chunks: List of text chunks to embed model_name: Name of the sentence transformer model - use_server: If True, use embedding server (good for search). If False, use direct computation (good for build). """ - if not use_server: - print( - f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..." - ) - return _compute_embeddings_sentence_transformers_direct(chunks, model_name) - print( f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (via embedding server)..." ) + import zmq + import msgpack + import numpy as np - # Use embedding server for sentence-transformers too - # This avoids loading the model twice (once in API, once in server) - try: - # Import ZMQ client functionality and server manager - import zmq - import msgpack - import numpy as np - from .embedding_server_manager import EmbeddingServerManager + # Connect to embedding server + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect(f"tcp://localhost:{port}") - # Ensure embedding server is running - port = 5557 - server_manager = EmbeddingServerManager( - backend_module_name="leann_backend_hnsw.hnsw_embedding_server" - ) + # Send chunks to server for embedding computation + request = chunks + socket.send(msgpack.packb(request)) - server_started = server_manager.start_server( - port=port, - model_name=model_name, - embedding_mode="sentence-transformers", - enable_warmup=False, - ) + # Receive embeddings from server + response = socket.recv() + embeddings_list = msgpack.unpackb(response) - if not server_started: - raise RuntimeError(f"Failed to start embedding server on port {port}") + # Convert back to numpy array + embeddings = np.array(embeddings_list, dtype=np.float32) - # Connect to embedding server - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect(f"tcp://localhost:{port}") - - # Send chunks to server for embedding computation - request = chunks - socket.send(msgpack.packb(request)) - - # Receive embeddings from server - response = socket.recv() - embeddings_list = msgpack.unpackb(response) - - # Convert back to numpy array - embeddings = np.array(embeddings_list, dtype=np.float32) - - socket.close() - context.term() - - return embeddings - - except Exception as e: - # Fallback to direct sentence-transformers if server connection fails - print( - f"Warning: Failed to connect to embedding server, falling back to direct computation: {e}" - ) - return _compute_embeddings_sentence_transformers_direct(chunks, model_name) - - -def _compute_embeddings_sentence_transformers_direct( - chunks: List[str], model_name: str -) -> np.ndarray: - """Direct sentence-transformers computation (fallback).""" - try: - from sentence_transformers import SentenceTransformer - except ImportError as e: - raise RuntimeError( - "sentence-transformers not available. 
Install with: uv pip install sentence-transformers" - ) from e - - # Load model using sentence-transformers - model = SentenceTransformer(model_name) - - model = model.half() - print( - f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}' (direct)..." - ) - # use acclerater GPU or MAC GPU - - if torch.cuda.is_available(): - model = model.to("cuda") - elif torch.backends.mps.is_available(): - model = model.to("mps") - - # Generate embeddings - # give use an warning if OOM here means we need to turn down the batch size - embeddings = model.encode( - chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=16 - ) + socket.close() + context.term() return embeddings -def compute_embeddings_openai(chunks: List[str], model_name: str) -> np.ndarray: - """Computes embeddings using OpenAI API.""" - try: - import openai - import os - except ImportError as e: - raise RuntimeError( - "openai not available. Install with: uv pip install openai" - ) from e - - # Get API key from environment - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise RuntimeError("OPENAI_API_KEY environment variable not set") - - client = openai.OpenAI(api_key=api_key) - - print( - f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'..." - ) - - # OpenAI has a limit on batch size and input length - max_batch_size = 100 # Conservative batch size - all_embeddings = [] - - try: - from tqdm import tqdm - total_batches = (len(chunks) + max_batch_size - 1) // max_batch_size - batch_range = range(0, len(chunks), max_batch_size) - batch_iterator = tqdm(batch_range, desc="Computing embeddings", unit="batch", total=total_batches) - except ImportError: - # Fallback without progress bar - batch_iterator = range(0, len(chunks), max_batch_size) - - for i in batch_iterator: - batch_chunks = chunks[i:i + max_batch_size] - - try: - response = client.embeddings.create(model=model_name, input=batch_chunks) - batch_embeddings = [embedding.embedding for embedding in response.data] - all_embeddings.extend(batch_embeddings) - except Exception as e: - print(f"ERROR: Failed to get embeddings for batch starting at {i}: {e}") - raise - - embeddings = np.array(all_embeddings, dtype=np.float32) - print( - f"INFO: Generated {len(embeddings)} embeddings with dimension {embeddings.shape[1]}" - ) - return embeddings - - -def compute_embeddings_mlx(chunks: List[str], model_name: str, batch_size: int = 16) -> np.ndarray: - """Computes embeddings using an MLX model.""" - try: - import mlx.core as mx - from mlx_lm.utils import load - from tqdm import tqdm - except ImportError as e: - raise RuntimeError( - "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm" - ) from e - - print( - f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..." 
- ) - - # Load model and tokenizer - model, tokenizer = load(model_name) - - # Process chunks in batches with progress bar - all_embeddings = [] - - try: - from tqdm import tqdm - batch_iterator = tqdm(range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch") - except ImportError: - batch_iterator = range(0, len(chunks), batch_size) - - for i in batch_iterator: - batch_chunks = chunks[i:i + batch_size] - - # Tokenize all chunks in the batch - batch_token_ids = [] - for chunk in batch_chunks: - token_ids = tokenizer.encode(chunk) # type: ignore - batch_token_ids.append(token_ids) - - # Pad sequences to the same length for batch processing - max_length = max(len(ids) for ids in batch_token_ids) - padded_token_ids = [] - for token_ids in batch_token_ids: - # Pad with tokenizer.pad_token_id or 0 - padded = token_ids + [0] * (max_length - len(token_ids)) - padded_token_ids.append(padded) - - # Convert to MLX array with batch dimension - input_ids = mx.array(padded_token_ids) - - # Get embeddings for the batch - embeddings = model(input_ids) - - # Mean pooling for each sequence in the batch - pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size) - - # Convert batch embeddings to numpy - for j in range(len(batch_chunks)): - pooled_list = pooled[j].tolist() # Convert to list - pooled_numpy = np.array(pooled_list, dtype=np.float32) - all_embeddings.append(pooled_numpy) - - # Stack numpy arrays - return np.stack(all_embeddings) - - @dataclass class SearchResult: id: str @@ -344,14 +150,12 @@ class LeannBuilder: self.dimensions = dimensions self.embedding_mode = embedding_mode self.backend_kwargs = backend_kwargs - if 'mlx' in self.embedding_model: - self.embedding_mode = "mlx" self.chunks: List[Dict[str, Any]] = [] def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None): if metadata is None: metadata = {} - passage_id = metadata.get("id", str(uuid.uuid4())) + passage_id = metadata.get("id", str(len(self.chunks))) chunk_data = {"id": passage_id, "text": text, "metadata": metadata} self.chunks.append(chunk_data) @@ -377,10 +181,13 @@ class LeannBuilder: with open(passages_file, "w", encoding="utf-8") as f: try: from tqdm import tqdm - chunk_iterator = tqdm(self.chunks, desc="Writing passages", unit="chunk") + + chunk_iterator = tqdm( + self.chunks, desc="Writing passages", unit="chunk" + ) except ImportError: chunk_iterator = self.chunks - + for chunk in chunk_iterator: offset = f.tell() json.dump( @@ -398,7 +205,11 @@ class LeannBuilder: pickle.dump(offset_map, f) texts_to_embed = [c["text"] for c in self.chunks] embeddings = compute_embeddings( - texts_to_embed, self.embedding_model, self.embedding_mode, use_server=False + texts_to_embed, + self.embedding_model, + self.embedding_mode, + use_server=False, + port=5557, ) string_ids = [chunk["id"] for chunk in self.chunks] current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py new file mode 100644 index 0000000..854265b --- /dev/null +++ b/packages/leann-core/src/leann/cli.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +import argparse +import asyncio +import sys +from pathlib import Path +from typing import Optional +import os + +from llama_index.core import SimpleDirectoryReader +from llama_index.core.node_parser import SentenceSplitter + +from .api import LeannBuilder, LeannSearcher, LeannChat + + +class LeannCLI: + def __init__(self): + self.indexes_dir = Path.home() / ".leann" / 
"indexes" + self.indexes_dir.mkdir(parents=True, exist_ok=True) + + self.node_parser = SentenceSplitter( + chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n" + ) + + def get_index_path(self, index_name: str) -> str: + index_dir = self.indexes_dir / index_name + return str(index_dir / "documents.leann") + + def index_exists(self, index_name: str) -> bool: + index_dir = self.indexes_dir / index_name + meta_file = index_dir / "documents.leann.meta.json" + return meta_file.exists() + + def create_parser(self) -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="leann", + description="LEANN - Local Enhanced AI Navigation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + leann build my-docs --docs ./documents # Build index named my-docs + leann search my-docs "query" # Search in my-docs index + leann ask my-docs "question" # Ask my-docs index + leann list # List all stored indexes + """ + ) + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # Build command + build_parser = subparsers.add_parser("build", help="Build document index") + build_parser.add_argument("index_name", help="Index name") + build_parser.add_argument("--docs", type=str, required=True, help="Documents directory") + build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]) + build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever") + build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild") + build_parser.add_argument("--graph-degree", type=int, default=32) + build_parser.add_argument("--complexity", type=int, default=64) + build_parser.add_argument("--num-threads", type=int, default=1) + build_parser.add_argument("--compact", action="store_true", default=True) + build_parser.add_argument("--recompute", action="store_true", default=True) + + # Search command + search_parser = subparsers.add_parser("search", help="Search documents") + search_parser.add_argument("index_name", help="Index name") + search_parser.add_argument("query", help="Search query") + search_parser.add_argument("--top-k", type=int, default=5) + search_parser.add_argument("--complexity", type=int, default=64) + search_parser.add_argument("--beam-width", type=int, default=1) + search_parser.add_argument("--prune-ratio", type=float, default=0.0) + search_parser.add_argument("--recompute-embeddings", action="store_true") + search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global") + + # Ask command + ask_parser = subparsers.add_parser("ask", help="Ask questions") + ask_parser.add_argument("index_name", help="Index name") + ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"]) + ask_parser.add_argument("--model", type=str, default="qwen3:8b") + ask_parser.add_argument("--host", type=str, default="http://localhost:11434") + ask_parser.add_argument("--interactive", "-i", action="store_true") + ask_parser.add_argument("--top-k", type=int, default=20) + ask_parser.add_argument("--complexity", type=int, default=32) + ask_parser.add_argument("--beam-width", type=int, default=1) + ask_parser.add_argument("--prune-ratio", type=float, default=0.0) + ask_parser.add_argument("--recompute-embeddings", action="store_true") + ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global") + + # List command + list_parser = 
subparsers.add_parser("list", help="List all indexes") + + return parser + + def list_indexes(self): + print("Stored LEANN indexes:") + + if not self.indexes_dir.exists(): + print("No indexes found. Use 'leann build --docs ' to create one.") + return + + index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()] + + if not index_dirs: + print("No indexes found. Use 'leann build --docs ' to create one.") + return + + print(f"Found {len(index_dirs)} indexes:") + for i, index_dir in enumerate(index_dirs, 1): + index_name = index_dir.name + status = "✓" if self.index_exists(index_name) else "✗" + + print(f" {i}. {index_name} [{status}]") + if self.index_exists(index_name): + meta_file = index_dir / "documents.leann.meta.json" + size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024) + print(f" Size: {size_mb:.1f} MB") + + if index_dirs: + example_name = index_dirs[0].name + print(f"\nUsage:") + print(f" leann search {example_name} \"your query\"") + print(f" leann ask {example_name} --interactive") + + def load_documents(self, docs_dir: str): + print(f"Loading documents from {docs_dir}...") + + documents = SimpleDirectoryReader( + docs_dir, + recursive=True, + encoding="utf-8", + required_exts=[".pdf", ".txt", ".md", ".docx"], + ).load_data(show_progress=True) + + all_texts = [] + for doc in documents: + nodes = self.node_parser.get_nodes_from_documents([doc]) + for node in nodes: + all_texts.append(node.get_content()) + + print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks") + return all_texts + + async def build_index(self, args): + docs_dir = args.docs + index_name = args.index_name + index_dir = self.indexes_dir / index_name + index_path = self.get_index_path(index_name) + + if index_dir.exists() and not args.force: + print(f"Index '{index_name}' already exists. Use --force to rebuild.") + return + + all_texts = self.load_documents(docs_dir) + if not all_texts: + print("No documents found") + return + + index_dir.mkdir(parents=True, exist_ok=True) + + print(f"Building index '{index_name}' with {args.backend} backend...") + + builder = LeannBuilder( + backend_name=args.backend, + embedding_model=args.embedding_model, + graph_degree=args.graph_degree, + complexity=args.complexity, + is_compact=args.compact, + is_recompute=args.recompute, + num_threads=args.num_threads, + ) + + for chunk_text in all_texts: + builder.add_text(chunk_text) + + builder.build_index(index_path) + print(f"Index built at {index_path}") + + async def search_documents(self, args): + index_name = args.index_name + query = args.query + index_path = self.get_index_path(index_name) + + if not self.index_exists(index_name): + print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs ' to create it.") + return + + searcher = LeannSearcher(index_path=index_path) + results = searcher.search( + query, + top_k=args.top_k, + complexity=args.complexity, + beam_width=args.beam_width, + prune_ratio=args.prune_ratio, + recompute_embeddings=args.recompute_embeddings, + pruning_strategy=args.pruning_strategy + ) + + print(f"Search results for '{query}' (top {len(results)}):") + for i, result in enumerate(results, 1): + print(f"{i}. Score: {result.score:.3f}") + print(f" {result.text[:200]}...") + print() + + async def ask_questions(self, args): + index_name = args.index_name + index_path = self.get_index_path(index_name) + + if not self.index_exists(index_name): + print(f"Index '{index_name}' not found. 
Use 'leann build {index_name} --docs ' to create it.") + return + + print(f"Starting chat with index '{index_name}'...") + print(f"Using {args.model} ({args.llm})") + + llm_config = {"type": args.llm, "model": args.model} + if args.llm == "ollama": + llm_config["host"] = args.host + + chat = LeannChat(index_path=index_path, llm_config=llm_config) + + if args.interactive: + print("LEANN Assistant ready! Type 'quit' to exit") + print("=" * 40) + + while True: + user_input = input("\nYou: ").strip() + if user_input.lower() in ['quit', 'exit', 'q']: + print("Goodbye!") + break + + if not user_input: + continue + + response = chat.ask( + user_input, + top_k=args.top_k, + complexity=args.complexity, + beam_width=args.beam_width, + prune_ratio=args.prune_ratio, + recompute_embeddings=args.recompute_embeddings, + pruning_strategy=args.pruning_strategy + ) + print(f"LEANN: {response}") + else: + query = input("Enter your question: ").strip() + if query: + response = chat.ask( + query, + top_k=args.top_k, + complexity=args.complexity, + beam_width=args.beam_width, + prune_ratio=args.prune_ratio, + recompute_embeddings=args.recompute_embeddings, + pruning_strategy=args.pruning_strategy + ) + print(f"LEANN: {response}") + + async def run(self, args=None): + parser = self.create_parser() + + if args is None: + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + if args.command == "list": + self.list_indexes() + elif args.command == "build": + await self.build_index(args) + elif args.command == "search": + await self.search_documents(args) + elif args.command == "ask": + await self.ask_questions(args) + else: + parser.print_help() + + +def main(): + import dotenv + dotenv.load_dotenv() + + cli = LeannCLI() + asyncio.run(cli.run()) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py new file mode 100644 index 0000000..20cef9f --- /dev/null +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -0,0 +1,272 @@ +""" +Unified embedding computation module +Consolidates all embedding computation logic using SentenceTransformer +Preserves all optimization parameters to ensure performance +""" + +import numpy as np +import torch +from typing import List +import logging + +logger = logging.getLogger(__name__) + + +def compute_embeddings( + texts: List[str], model_name: str, mode: str = "sentence-transformers" +) -> np.ndarray: + """ + Unified embedding computation entry point + + Args: + texts: List of texts to compute embeddings for + model_name: Model name + mode: Computation mode ('sentence-transformers', 'openai', 'mlx') + + Returns: + Normalized embeddings array, shape: (len(texts), embedding_dim) + """ + if mode == "sentence-transformers": + return compute_embeddings_sentence_transformers(texts, model_name) + elif mode == "openai": + return compute_embeddings_openai(texts, model_name) + elif mode == "mlx": + return compute_embeddings_mlx(texts, model_name) + else: + raise ValueError(f"Unsupported embedding mode: {mode}") + + +def compute_embeddings_sentence_transformers( + texts: List[str], + model_name: str, + use_fp16: bool = True, + device: str = "auto", + batch_size: int = 32, +) -> np.ndarray: + """ + Compute embeddings using SentenceTransformer + Preserves all optimization parameters to ensure consistency with original embedding_server + + Args: + texts: List of texts to compute embeddings for + model_name: SentenceTransformer 
model name + use_fp16: Whether to use FP16 precision + device: Device selection ('auto', 'cuda', 'mps', 'cpu') + batch_size: Batch size for processing + + Returns: + Normalized embeddings array, shape: (len(texts), embedding_dim) + """ + print( + f"INFO: Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'" + ) + + from sentence_transformers import SentenceTransformer + + # Auto-detect device + if device == "auto": + if torch.cuda.is_available(): + device = "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + print(f"INFO: Using device: {device}") + + # Prepare model and tokenizer optimization parameters (consistent with original embedding_server) + model_kwargs = { + "torch_dtype": torch.float16 if use_fp16 else torch.float32, + "low_cpu_mem_usage": True, + "_fast_init": True, # Skip weight initialization checks for faster loading + } + + tokenizer_kwargs = { + "use_fast": True, # Use fast tokenizer for better runtime performance + } + + # Load SentenceTransformer (try local first, then network) + print(f"INFO: Loading SentenceTransformer model: {model_name}") + + try: + # Try local loading (avoid network delays) + model_kwargs["local_files_only"] = True + tokenizer_kwargs["local_files_only"] = True + + model = SentenceTransformer( + model_name, + device=device, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + local_files_only=True, + ) + print("✅ Model loaded successfully! (local + optimized)") + except Exception as e: + print(f"Local loading failed ({e}), trying network download...") + # Fallback to network loading + model_kwargs["local_files_only"] = False + tokenizer_kwargs["local_files_only"] = False + + model = SentenceTransformer( + model_name, + device=device, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + local_files_only=False, + ) + print("✅ Model loaded successfully! 
(network + optimized)") + + # Apply additional optimizations (if supported) + if use_fp16 and device in ["cuda", "mps"]: + try: + model = model.half() + model = torch.compile(model) + print(f"✅ Using FP16 precision and compile optimization: {model_name}") + except Exception as e: + print( + f"FP16 or compile optimization failed, continuing with default settings: {e}" + ) + + # Compute embeddings (using SentenceTransformer's optimized implementation) + print("INFO: Starting embedding computation...") + + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=False, # Don't show progress bar in server environment + convert_to_numpy=True, + normalize_embeddings=False, # Keep consistent with original API behavior + device=device, + ) + + print( + f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}" + ) + + # Validate results + if np.isnan(embeddings).any() or np.isinf(embeddings).any(): + raise RuntimeError( + f"Detected NaN or Inf values in embeddings, model: {model_name}" + ) + + return embeddings + + +def compute_embeddings_openai(texts: List[str], model_name: str) -> np.ndarray: + """Compute embeddings using OpenAI API""" + try: + import openai + import os + except ImportError as e: + raise ImportError(f"OpenAI package not installed: {e}") + + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise RuntimeError("OPENAI_API_KEY environment variable not set") + + client = openai.OpenAI(api_key=api_key) + + print( + f"INFO: Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" + ) + + # OpenAI has limits on batch size and input length + max_batch_size = 100 # Conservative batch size + all_embeddings = [] + + try: + from tqdm import tqdm + + total_batches = (len(texts) + max_batch_size - 1) // max_batch_size + batch_range = range(0, len(texts), max_batch_size) + batch_iterator = tqdm( + batch_range, desc="Computing embeddings", unit="batch", total=total_batches + ) + except ImportError: + # Fallback when tqdm is not available + batch_iterator = range(0, len(texts), max_batch_size) + + for i in batch_iterator: + batch_texts = texts[i : i + max_batch_size] + + try: + response = client.embeddings.create(model=model_name, input=batch_texts) + batch_embeddings = [embedding.embedding for embedding in response.data] + all_embeddings.extend(batch_embeddings) + except Exception as e: + print(f"ERROR: Batch {i} failed: {e}") + raise + + embeddings = np.array(all_embeddings, dtype=np.float32) + print( + f"INFO: Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}" + ) + return embeddings + + +def compute_embeddings_mlx( + chunks: List[str], model_name: str, batch_size: int = 16 +) -> np.ndarray: + """Computes embeddings using an MLX model.""" + try: + import mlx.core as mx + from mlx_lm.utils import load + from tqdm import tqdm + except ImportError as e: + raise RuntimeError( + "MLX or related libraries not available. Install with: uv pip install mlx mlx-lm" + ) from e + + print( + f"INFO: Computing embeddings for {len(chunks)} chunks using MLX model '{model_name}' with batch_size={batch_size}..." 
+ ) + + # Load model and tokenizer + model, tokenizer = load(model_name) + + # Process chunks in batches with progress bar + all_embeddings = [] + + try: + from tqdm import tqdm + + batch_iterator = tqdm( + range(0, len(chunks), batch_size), desc="Computing embeddings", unit="batch" + ) + except ImportError: + batch_iterator = range(0, len(chunks), batch_size) + + for i in batch_iterator: + batch_chunks = chunks[i : i + batch_size] + + # Tokenize all chunks in the batch + batch_token_ids = [] + for chunk in batch_chunks: + token_ids = tokenizer.encode(chunk) # type: ignore + batch_token_ids.append(token_ids) + + # Pad sequences to the same length for batch processing + max_length = max(len(ids) for ids in batch_token_ids) + padded_token_ids = [] + for token_ids in batch_token_ids: + # Pad with tokenizer.pad_token_id or 0 + padded = token_ids + [0] * (max_length - len(token_ids)) + padded_token_ids.append(padded) + + # Convert to MLX array with batch dimension + input_ids = mx.array(padded_token_ids) + + # Get embeddings for the batch + embeddings = model(input_ids) + + # Mean pooling for each sequence in the batch + pooled = embeddings.mean(axis=1) # Shape: (batch_size, hidden_size) + + # Convert batch embeddings to numpy + for j in range(len(batch_chunks)): + pooled_list = pooled[j].tolist() # Convert to list + pooled_numpy = np.array(pooled_list, dtype=np.float32) + all_embeddings.append(pooled_numpy) + + # Stack numpy arrays + return np.stack(all_embeddings) diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 2022262..6a44160 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -4,11 +4,10 @@ import atexit import socket import subprocess import sys -import zmq -import msgpack from pathlib import Path from typing import Optional import select +import psutil def _check_port(port: int) -> bool: @@ -17,151 +16,135 @@ def _check_port(port: int) -> bool: return s.connect_ex(("localhost", port)) == 0 -def _check_server_meta_path(port: int, expected_meta_path: str) -> bool: +def _check_process_matches_config( + port: int, expected_model: str, expected_passages_file: str +) -> bool: """ - Check if the existing server on the port is using the correct meta file. - Returns True if the server has the right meta path, False otherwise. + Check if the process using the port matches our expected model and passages file. + Returns True if matches, False otherwise. 
""" try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout - socket.connect(f"tcp://localhost:{port}") + for proc in psutil.process_iter(["pid", "cmdline"]): + if not _is_process_listening_on_port(proc, port): + continue - # Send a special control message to query the server's meta path - control_request = ["__QUERY_META_PATH__"] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) + cmdline = proc.info["cmdline"] + if not cmdline: + continue - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the response contains the meta path and if it matches - if isinstance(response, list) and len(response) > 0: - server_meta_path = response[0] - # Normalize paths for comparison - expected_path = Path(expected_meta_path).resolve() - server_path = Path(server_meta_path).resolve() if server_meta_path else None - return server_path == expected_path + return _check_cmdline_matches_config( + cmdline, port, expected_model, expected_passages_file + ) + print(f"DEBUG: No process found listening on port {port}") return False except Exception as e: - print(f"WARNING: Could not query server meta path on port {port}: {e}") + print(f"WARNING: Could not check process on port {port}: {e}") return False -def _update_server_meta_path(port: int, new_meta_path: str) -> bool: - """ - Send a control message to update the server's meta path. - Returns True if successful, False otherwise. - """ +def _is_process_listening_on_port(proc, port: int) -> bool: + """Check if a process is listening on the given port.""" try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout - socket.connect(f"tcp://localhost:{port}") - - # Send a control message to update the meta path - control_request = ["__UPDATE_META_PATH__", new_meta_path] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) - - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the update was successful - if isinstance(response, list) and len(response) > 0: - return response[0] == "SUCCESS" - + connections = proc.net_connections() + for conn in connections: + if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN: + return True return False - - except Exception as e: - print(f"ERROR: Could not update server meta path on port {port}: {e}") + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): return False -def _check_server_model(port: int, expected_model: str) -> bool: +def _check_cmdline_matches_config( + cmdline: list, port: int, expected_model: str, expected_passages_file: str +) -> bool: + """Check if command line matches our expected configuration.""" + cmdline_str = " ".join(cmdline) + print(f"DEBUG: Found process on port {port}: {cmdline_str}") + + # Check if it's our embedding server + is_embedding_server = any( + server_type in cmdline_str + for server_type in [ + "embedding_server", + "leann_backend_diskann.embedding_server", + "leann_backend_hnsw.hnsw_embedding_server", + ] + ) + + if not is_embedding_server: + print(f"DEBUG: Process on port {port} is not our embedding server") + return False + + # Check model name + model_matches = _check_model_in_cmdline(cmdline, expected_model) + + # Check passages file if provided + passages_matches = 
_check_passages_in_cmdline(cmdline, expected_passages_file) + + result = model_matches and passages_matches + print( + f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}" + ) + return result + + +def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool: + """Check if the command line contains the expected model.""" + if "--model-name" not in cmdline: + return False + + model_idx = cmdline.index("--model-name") + if model_idx + 1 >= len(cmdline): + return False + + actual_model = cmdline[model_idx + 1] + return actual_model == expected_model + + +def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str) -> bool: + """Check if the command line contains the expected passages file.""" + if "--passages-file" not in cmdline: + return False # Expected but not found + + passages_idx = cmdline.index("--passages-file") + if passages_idx + 1 >= len(cmdline): + return False + + actual_passages = cmdline[passages_idx + 1] + expected_path = Path(expected_passages_file).resolve() + actual_path = Path(actual_passages).resolve() + return actual_path == expected_path + + +def _find_compatible_port_or_next_available( + start_port: int, model_name: str, passages_file: str, max_attempts: int = 100 +) -> tuple[int, bool]: """ - Check if the existing server on the port is using the correct embedding model. - Returns True if the server has the right model, False otherwise. + Find a port that either has a compatible server or is available. + Returns (port, is_compatible) where is_compatible indicates if we found a matching server. """ - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout - socket.connect(f"tcp://localhost:{port}") + for port in range(start_port, start_port + max_attempts): + if not _check_port(port): + # Port is available + return port, False - # Send a special control message to query the server's model - control_request = ["__QUERY_MODEL__"] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) + # Port is in use, check if it's compatible + if _check_process_matches_config(port, model_name, passages_file): + print(f"✅ Found compatible server on port {port}") + return port, True + else: + print(f"⚠️ Port {port} has incompatible server, trying next port...") - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the response contains the model name and if it matches - if isinstance(response, list) and len(response) > 0: - server_model = response[0] - return server_model == expected_model - - return False - - except Exception as e: - print(f"WARNING: Could not query server model on port {port}: {e}") - return False - - -def _update_server_model(port: int, new_model: str) -> bool: - """ - Send a control message to update the server's embedding model. - Returns True if successful, False otherwise. 
- """ - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading - socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending - socket.connect(f"tcp://localhost:{port}") - - # Send a control message to update the model - control_request = ["__UPDATE_MODEL__", new_model] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) - - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the update was successful - if isinstance(response, list) and len(response) > 0: - return response[0] == "SUCCESS" - - return False - - except Exception as e: - print(f"ERROR: Could not update server model on port {port}: {e}") - return False + raise RuntimeError( + f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}" + ) class EmbeddingServerManager: """ - A generic manager for handling the lifecycle of a backend-specific embedding server process. + A simplified manager for embedding server processes that avoids complex update mechanisms. """ def __init__(self, backend_module_name: str): @@ -175,210 +158,162 @@ class EmbeddingServerManager: self.backend_module_name = backend_module_name self.server_process: Optional[subprocess.Popen] = None self.server_port: Optional[int] = None - atexit.register(self.stop_server) + self._atexit_registered = False - def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool: + def start_server( + self, + port: int, + model_name: str, + embedding_mode: str = "sentence-transformers", + **kwargs, + ) -> tuple[bool, int]: """ Starts the embedding server process. Args: - port (int): The ZMQ port for the server. + port (int): The preferred ZMQ port for the server. model_name (str): The name of the embedding model to use. - **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup). + **kwargs: Additional arguments for the server. Returns: - bool: True if the server is started successfully or already running, False otherwise. + tuple[bool, int]: (success, actual_port_used) """ - if self.server_process and self.server_process.poll() is None: - # Even if we have a running process, check if model/meta path match - if self.server_port is not None: - port_in_use = _check_port(self.server_port) - if port_in_use: - print( - f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})" - ) + passages_file = kwargs.get("passages_file") + assert isinstance(passages_file, str), "passages_file must be a string" - # Check model compatibility - model_matches = _check_server_model(self.server_port, model_name) - if model_matches: - print( - f"✅ Existing server already using correct model: {model_name}" - ) - - # Still check meta path if provided - passages_file = kwargs.get("passages_file") - if passages_file and str(passages_file).endswith( - ".meta.json" - ): - meta_matches = _check_server_meta_path( - self.server_port, str(passages_file) - ) - if not meta_matches: - print("⚠️ Updating meta path to: {passages_file}") - _update_server_meta_path( - self.server_port, str(passages_file) - ) - - return True - else: - print( - f"⚠️ Existing server has different model. Attempting to update to: {model_name}" - ) - if not _update_server_model(self.server_port, model_name): - print( - "❌ Failed to update existing server model. 
Restarting server..." - ) - self.stop_server() - # Continue to start new server below - else: - print( - f"✅ Successfully updated existing server model to: {model_name}" - ) + # Check if we have a compatible running server + if self._has_compatible_running_server(model_name, passages_file): + assert self.server_port is not None, ( + "a compatible running server should set server_port" + ) + return True, self.server_port - # Also check meta path if provided - passages_file = kwargs.get("passages_file") - if passages_file and str(passages_file).endswith( - ".meta.json" - ): - meta_matches = _check_server_meta_path( - self.server_port, str(passages_file) - ) - if not meta_matches: - print("⚠️ Updating meta path to: {passages_file}") - _update_server_meta_path( - self.server_port, str(passages_file) - ) + # Find available port (compatible or free) + try: + actual_port, is_compatible = _find_compatible_port_or_next_available( + port, model_name, passages_file + ) + except RuntimeError as e: + print(f"❌ {e}") + return False, port - return True - else: - # Server process exists but port not responding - restart - print("⚠️ Server process exists but not responding. Restarting...") - self.stop_server() - # Continue to start new server below - else: - # No port stored - restart - print("⚠️ No port information stored. Restarting server...") - self.stop_server() - # Continue to start new server below + if is_compatible: + print(f"✅ Using existing compatible server on port {actual_port}") + self.server_port = actual_port + self.server_process = None # We don't own this process + return True, actual_port - if _check_port(port): - # Port is in use, check if it's using the correct meta file and model - passages_file = kwargs.get("passages_file") + if actual_port != port: + print(f"⚠️ Using port {actual_port} instead of {port}") - print(f"INFO: Port {port} is in use. Checking server compatibility...") + # Start new server + return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) - # Check model compatibility first - model_matches = _check_server_model(port, model_name) - if model_matches: - print( - f"✅ Existing server on port {port} is using correct model: {model_name}" - ) - else: - print( - f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}" - ) - if not _update_server_model(port, model_name): - raise RuntimeError( - f"❌ Failed to update server model to {model_name}. Consider using a different port." - ) - print(f"✅ Successfully updated server model to: {model_name}") + def _has_compatible_running_server( + self, model_name: str, passages_file: str + ) -> bool: + """Check if we have a compatible running server.""" + if not ( + self.server_process + and self.server_process.poll() is None + and self.server_port + ): + return False - # Check meta path compatibility if provided - if passages_file and str(passages_file).endswith(".meta.json"): - meta_matches = _check_server_meta_path(port, str(passages_file)) - if not meta_matches: - print( - f"⚠️ Existing server on port {port} has different meta path. Attempting to update..." - ) - if not _update_server_meta_path(port, str(passages_file)): - raise RuntimeError( - "❌ Failed to update server meta path. This may cause data synchronization issues." 
- ) - print( - f"✅ Successfully updated server meta path to: {passages_file}" - ) - else: - print( - f"✅ Existing server on port {port} is using correct meta path: {passages_file}" - ) - - print(f"✅ Server on port {port} is compatible and ready to use.") + if _check_process_matches_config(self.server_port, model_name, passages_file): + print( + f"✅ Existing server process (PID {self.server_process.pid}) is compatible" + ) return True - print( - f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..." - ) + print("⚠️ Existing server process is incompatible. Should start a new server.") + return False + + def _start_new_server( + self, port: int, model_name: str, embedding_mode: str, **kwargs + ) -> tuple[bool, int]: + """Start a new embedding server on the given port.""" + print(f"INFO: Starting embedding server on port {port}...") + + command = self._build_server_command(port, model_name, embedding_mode, **kwargs) try: - command = [ - sys.executable, - "-m", - self.backend_module_name, - "--zmq-port", - str(port), - "--model-name", - model_name, - ] - - # Add extra arguments for specific backends - if "passages_file" in kwargs and kwargs["passages_file"]: - command.extend(["--passages-file", str(kwargs["passages_file"])]) - # if "distance_metric" in kwargs and kwargs["distance_metric"]: - # command.extend(["--distance-metric", kwargs["distance_metric"]]) - if embedding_mode != "sentence-transformers": - command.extend(["--embedding-mode", embedding_mode]) - if "enable_warmup" in kwargs and not kwargs["enable_warmup"]: - command.extend(["--disable-warmup"]) - - project_root = Path(__file__).parent.parent.parent.parent.parent - print(f"INFO: Running command from project root: {project_root}") - print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command - - self.server_process = subprocess.Popen( - command, - cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring - text=True, - encoding="utf-8", - bufsize=1, # Line buffered - universal_newlines=True, - ) - self.server_port = port - print(f"INFO: Server process started with PID: {self.server_process.pid}") - - max_wait, wait_interval = 120, 0.5 - for _ in range(int(max_wait / wait_interval)): - if _check_port(port): - print("✅ Embedding server is up and ready for this session.") - log_thread = threading.Thread(target=self._log_monitor, daemon=True) - log_thread.start() - return True - if self.server_process.poll() is not None: - print( - "❌ ERROR: Server process terminated unexpectedly during startup." - ) - self._print_recent_output() - return False - time.sleep(wait_interval) - - print( - f"❌ ERROR: Server process failed to start listening within {max_wait} seconds." 
- ) - self.stop_server() - return False - + self._launch_server_process(command, port) + return self._wait_for_server_ready(port) except Exception as e: - print(f"❌ ERROR: Failed to start embedding server process: {e}") - return False + print(f"❌ ERROR: Failed to start embedding server: {e}") + return False, port + + def _build_server_command( + self, port: int, model_name: str, embedding_mode: str, **kwargs + ) -> list: + """Build the command to start the embedding server.""" + command = [ + sys.executable, + "-m", + self.backend_module_name, + "--zmq-port", + str(port), + "--model-name", + model_name, + ] + + if kwargs.get("passages_file"): + command.extend(["--passages-file", str(kwargs["passages_file"])]) + if embedding_mode != "sentence-transformers": + command.extend(["--embedding-mode", embedding_mode]) + + return command + + def _launch_server_process(self, command: list, port: int) -> None: + """Launch the server process.""" + project_root = Path(__file__).parent.parent.parent.parent.parent + print(f"INFO: Command: {' '.join(command)}") + + self.server_process = subprocess.Popen( + command, + cwd=project_root, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + encoding="utf-8", + bufsize=1, + universal_newlines=True, + ) + self.server_port = port + print(f"INFO: Server process started with PID: {self.server_process.pid}") + + # Register atexit callback only when we actually start a process + if not self._atexit_registered: + # Use a lambda to avoid issues with bound methods + atexit.register(lambda: self.stop_server() if self.server_process else None) + self._atexit_registered = True + + def _wait_for_server_ready(self, port: int) -> tuple[bool, int]: + """Wait for the server to be ready.""" + max_wait, wait_interval = 120, 0.5 + for _ in range(int(max_wait / wait_interval)): + if _check_port(port): + print("✅ Embedding server is ready!") + threading.Thread(target=self._log_monitor, daemon=True).start() + return True, port + + if self.server_process.poll() is not None: + print("❌ ERROR: Server terminated during startup.") + self._print_recent_output() + return False, port + + time.sleep(wait_interval) + + print(f"❌ ERROR: Server failed to start within {max_wait} seconds.") + self.stop_server() + return False, port def _print_recent_output(self): """Print any recent output from the server process.""" if not self.server_process or not self.server_process.stdout: return try: - # Read any available output - if select.select([self.server_process.stdout], [], [], 0)[0]: output = self.server_process.stdout.read() if output: @@ -404,17 +339,26 @@ class EmbeddingServerManager: def stop_server(self): """Stops the embedding server process if it's running.""" - if self.server_process and self.server_process.poll() is None: + if not self.server_process: + return + + if self.server_process.poll() is not None: + # Process already terminated + self.server_process = None + return + + print( + f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." + ) + self.server_process.terminate() + + try: + self.server_process.wait(timeout=5) + print(f"INFO: Server process {self.server_process.pid} terminated.") + except subprocess.TimeoutExpired: print( - f"INFO: Terminating session server process (PID: {self.server_process.pid})..." + f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it." 
) - self.server_process.terminate() - try: - self.server_process.wait(timeout=5) - print("INFO: Server process terminated.") - except subprocess.TimeoutExpired: - print( - "WARNING: Server process did not terminate gracefully, killing it." - ) - self.server_process.kill() + self.server_process.kill() + self.server_process = None diff --git a/packages/leann-core/src/leann/registry.py b/packages/leann-core/src/leann/registry.py index bda797a..043a784 100644 --- a/packages/leann-core/src/leann/registry.py +++ b/packages/leann-core/src/leann/registry.py @@ -7,30 +7,37 @@ import importlib.metadata if TYPE_CHECKING: from leann.interface import LeannBackendFactoryInterface -BACKEND_REGISTRY: Dict[str, 'LeannBackendFactoryInterface'] = {} +BACKEND_REGISTRY: Dict[str, "LeannBackendFactoryInterface"] = {} + def register_backend(name: str): """A decorator to register a new backend class.""" + def decorator(cls): print(f"INFO: Registering backend '{name}'") BACKEND_REGISTRY[name] = cls return cls + return decorator + def autodiscover_backends(): """Automatically discovers and imports all 'leann-backend-*' packages.""" - print("INFO: Starting backend auto-discovery...") + # print("INFO: Starting backend auto-discovery...") discovered_backends = [] for dist in importlib.metadata.distributions(): - dist_name = dist.metadata['name'] - if dist_name.startswith('leann-backend-'): - backend_module_name = dist_name.replace('-', '_') + dist_name = dist.metadata["name"] + if dist_name.startswith("leann-backend-"): + backend_module_name = dist_name.replace("-", "_") discovered_backends.append(backend_module_name) - - for backend_module_name in sorted(discovered_backends): # sort for deterministic loading + + for backend_module_name in sorted( + discovered_backends + ): # sort for deterministic loading try: importlib.import_module(backend_module_name) # Registration message is printed by the decorator except ImportError as e: - print(f"WARN: Could not import backend module '{backend_module_name}': {e}") - print("INFO: Backend auto-discovery finished.") \ No newline at end of file + # print(f"WARN: Could not import backend module '{backend_module_name}': {e}") + pass + # print("INFO: Backend auto-discovery finished.") diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index 0f40a85..dfa6c2d 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -43,8 +43,6 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): "WARNING: embedding_model not found in meta.json. Recompute will fail." ) - self.label_map = self._load_label_map() - self.embedding_server_manager = EmbeddingServerManager( backend_module_name=backend_module_name ) @@ -58,17 +56,9 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): with open(meta_path, "r", encoding="utf-8") as f: return json.load(f) - def _load_label_map(self) -> Dict[int, str]: - """Loads the mapping from integer IDs to string IDs.""" - label_map_file = self.index_dir / "leann.labels.map" - if not label_map_file.exists(): - raise FileNotFoundError(f"Label map file not found: {label_map_file}") - with open(label_map_file, "rb") as f: - return pickle.load(f) - def _ensure_server_running( self, passages_source_file: str, port: int, **kwargs - ) -> None: + ) -> int: """ Ensures the embedding server is running if recompute is needed. This is a helper for subclasses. 
@@ -79,8 +69,8 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): ) embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") - - server_started = self.embedding_server_manager.start_server( + + server_started, actual_port = self.embedding_server_manager.start_server( port=port, model_name=self.embedding_model, passages_file=passages_source_file, @@ -89,7 +79,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): enable_warmup=kwargs.get("enable_warmup", False), ) if not server_started: - raise RuntimeError(f"Failed to start embedding server on port {port}") + raise RuntimeError( + f"Failed to start embedding server on port {actual_port}" + ) + + return actual_port def compute_query_embedding( self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True @@ -106,12 +100,16 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): Query embedding as numpy array """ # Try to use embedding server if available and requested - if ( - use_server_if_available - and self.embedding_server_manager - and self.embedding_server_manager.server_process - ): + if use_server_if_available: try: + # Ensure we have a server with passages_file for compatibility + passages_source_file = ( + self.index_dir / f"{self.index_path.name}.meta.json" + ) + zmq_port = self._ensure_server_running( + str(passages_source_file), zmq_port + ) + return self._compute_embedding_via_server([query], zmq_port)[ 0:1 ] # Return (1, D) shape diff --git a/pyproject.toml b/pyproject.toml index 3a0c027..b42dcbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "llama-index-embeddings-huggingface>=0.5.5", "mlx>=0.26.3", "mlx-lm>=0.26.0", + "psutil>=5.8.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index aa661f6..6c84ad2 100644 --- a/uv.lock +++ b/uv.lock @@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tqdm" }, ] [package.metadata] -requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }] +requires-dist = [ + { name = "numpy", specifier = ">=1.20.0" }, + { name = "tqdm", specifier = ">=4.60.0" }, +] [[package]] name = "leann-workspace" @@ -1851,7 +1855,6 @@ dependencies = [ { name = "flask" }, { name = "flask-compress" }, { name = "ipykernel" }, - { name = "leann-backend-diskann" }, { name = "leann-backend-hnsw" }, { name = "leann-core" }, { name = "llama-index" }, @@ -1867,6 +1870,7 @@ dependencies = [ { name = "ollama" }, { name = "openai" }, { name = "protobuf" }, + { name = "psutil" }, { name = "pypdf2" }, { name = "requests" }, { name = "sentence-transformers" }, @@ -1884,6 +1888,9 @@ dev = [ { name = "pytest-cov" }, { name = "ruff" }, ] +diskann = [ + { name = "leann-backend-diskann" }, +] [package.metadata] requires-dist = [ @@ -1896,7 +1903,7 @@ requires-dist = [ { name = "flask-compress" }, { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" }, { name = "ipykernel", specifier = "==6.29.5" }, - { name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" }, + { name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" }, { name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" }, { name = "leann-core", editable 
= "packages/leann-core" }, { name = "llama-index", specifier = ">=0.12.44" }, @@ -1912,6 +1919,7 @@ requires-dist = [ { name = "ollama" }, { name = "openai", specifier = ">=1.0.0" }, { name = "protobuf", specifier = "==4.25.3" }, + { name = "psutil", specifier = ">=5.8.0" }, { name = "pypdf2", specifier = ">=3.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, @@ -1922,7 +1930,7 @@ requires-dist = [ { name = "torch" }, { name = "tqdm" }, ] -provides-extras = ["dev"] +provides-extras = ["dev", "diskann"] [[package]] name = "llama-cloud"