diff --git a/README.md b/README.md
index 040a109..3624c81 100755
--- a/README.md
+++ b/README.md
@@ -292,6 +292,71 @@ Once the index is built, you can ask questions like:
+## 🖥️ Command Line Interface
+
+LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.
+
+```bash
+# Build an index from documents
+leann build my-docs --docs ./documents
+
+# Search your documents
+leann search my-docs "machine learning concepts"
+
+# Interactive chat with your documents
+leann ask my-docs --interactive
+
+# List all your indexes
+leann list
+```
+
+**Key CLI features:**
+- Auto-detects document formats (PDF, TXT, MD, DOCX)
+- Smart text chunking with overlap
+- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
+- Organized index storage in `~/.leann/indexes/`
+- Support for advanced search parameters
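+
+The CLI is a thin wrapper around the Python API (`LeannBuilder`, `LeannSearcher`, `LeannChat`). Below is a minimal sketch of the equivalent calls; the import path is an assumption, and the options shown are a subset of what the CLI passes:
+
+```python
+from leann.api import LeannBuilder, LeannSearcher, LeannChat  # assumed import path
+
+index_path = "./my-docs.leann"  # example path; the CLI stores indexes under ~/.leann/indexes/<name>/documents.leann
+
+# Build: add text chunks, then write the index
+builder = LeannBuilder(backend_name="hnsw", embedding_model="facebook/contriever")
+builder.add_text("Example passage text goes here.")
+builder.build_index(index_path)  # other options (graph_degree, complexity, ...) omitted for brevity
+
+# Search: returns scored results exposing .score and .text
+searcher = LeannSearcher(index_path=index_path)
+for result in searcher.search("machine learning concepts", top_k=5):
+    print(f"{result.score:.3f}  {result.text[:80]}")
+
+# Ask: retrieval-augmented chat over the same index
+chat = LeannChat(index_path=index_path, llm_config={"type": "ollama", "model": "qwen3:8b"})
+print(chat.ask("What are the main topics in these documents?", top_k=20))
+```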
+
+
+<details>
+<summary>📖 Click to expand: Complete CLI Reference</summary>
+
+**Build Command:**
+```bash
+leann build INDEX_NAME --docs DIRECTORY [OPTIONS]
+
+Options:
+ --backend {hnsw,diskann} Backend to use (default: hnsw)
+ --embedding-model MODEL Embedding model (default: facebook/contriever)
+ --graph-degree N Graph degree (default: 32)
+ --complexity N Build complexity (default: 64)
+ --force Force rebuild existing index
+ --compact Use compact storage (default: true)
+ --recompute Enable recomputation (default: true)
+```
+
+**Search Command:**
+```bash
+leann search INDEX_NAME QUERY [OPTIONS]
+
+Options:
+ --top-k N Number of results (default: 5)
+ --complexity N Search complexity (default: 64)
+ --recompute-embeddings Use recomputation for highest accuracy
+ --pruning-strategy {global,local,proportional}
+```
+
+**Ask Command:**
+```bash
+leann ask INDEX_NAME [OPTIONS]
+
+Options:
+ --llm {ollama,openai,hf} LLM provider (default: ollama)
+ --model MODEL Model name (default: qwen3:8b)
+ --interactive Interactive chat mode
+ --top-k N Retrieval count (default: 20)
+```
+
+</details>
+
## 🏗️ Architecture & How It Works
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
index a28a744..b4df2f7 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
@@ -1,10 +1,13 @@
import numpy as np
import os
import struct
+import sys
from pathlib import Path
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional
import contextlib
+import logging
+
from leann.searcher_base import BaseSearcher
from leann.registry import register_backend
from leann.interface import (
@@ -13,6 +16,46 @@ from leann.interface import (
LeannBackendSearcherInterface,
)
+logger = logging.getLogger(__name__)
+
+
+@contextlib.contextmanager
+def suppress_cpp_output_if_needed():
+ """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
+ log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+
+ # Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
+ should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
+
+ if not should_suppress:
+ # Don't suppress, just yield
+ yield
+ return
+
+ # Save original file descriptors
+ stdout_fd = sys.stdout.fileno()
+ stderr_fd = sys.stderr.fileno()
+
+ # Save original stdout/stderr
+ stdout_dup = os.dup(stdout_fd)
+ stderr_dup = os.dup(stderr_fd)
+
+ try:
+ # Redirect to /dev/null
+ devnull = os.open(os.devnull, os.O_WRONLY)
+ os.dup2(devnull, stdout_fd)
+ os.dup2(devnull, stderr_fd)
+ os.close(devnull)
+
+ yield
+
+ finally:
+ # Restore original file descriptors
+ os.dup2(stdout_dup, stdout_fd)
+ os.dup2(stderr_dup, stderr_fd)
+ os.close(stdout_dup)
+ os.close(stderr_dup)
+
def _get_diskann_metrics():
from . import _diskannpy as diskannpy # type: ignore
@@ -64,6 +107,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
+ logger.warning(f"Converting data to float32, shape: {data.shape}")
data = data.astype(np.float32)
data_filename = f"{index_prefix}_data.bin"
@@ -74,7 +118,9 @@ class DiskannBuilder(LeannBackendBuilderInterface):
build_kwargs.get("distance_metric", "mips").lower()
)
if metric_enum is None:
- raise ValueError("Unsupported distance_metric.")
+ raise ValueError(
+ f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
+ )
try:
from . import _diskannpy as diskannpy # type: ignore
@@ -96,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
+ logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
class DiskannSearcher(BaseSearcher):
def __init__(self, index_path: str, **kwargs):
super().__init__(
index_path,
- backend_module_name="leann_backend_diskann.embedding_server",
+ backend_module_name="leann_backend_diskann.diskann_embedding_server",
**kwargs,
)
- from . import _diskannpy as diskannpy # type: ignore
- distance_metric = kwargs.get("distance_metric", "mips").lower()
- metric_enum = _get_diskann_metrics().get(distance_metric)
- if metric_enum is None:
- raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
+ # Initialize DiskANN index with suppressed C++ output based on log level
+ with suppress_cpp_output_if_needed():
+ from . import _diskannpy as diskannpy # type: ignore
- self.num_threads = kwargs.get("num_threads", 8)
- self.zmq_port = kwargs.get("zmq_port", 6666)
+ distance_metric = kwargs.get("distance_metric", "mips").lower()
+ metric_enum = _get_diskann_metrics().get(distance_metric)
+ if metric_enum is None:
+ raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
- full_index_prefix = str(self.index_dir / self.index_path.stem)
- self._index = diskannpy.StaticDiskFloatIndex(
- metric_enum,
- full_index_prefix,
- self.num_threads,
- kwargs.get("num_nodes_to_cache", 0),
- 1,
- self.zmq_port,
- "",
- "",
- )
+ self.num_threads = kwargs.get("num_threads", 8)
+
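+            # Placeholder port passed at construction time; search() later calls
+            # set_zmq_port() with the real embedding-server port before any recompute query.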
+ fake_zmq_port = 6666
+ full_index_prefix = str(self.index_dir / self.index_path.stem)
+ self._index = diskannpy.StaticDiskFloatIndex(
+ metric_enum,
+ full_index_prefix,
+ self.num_threads,
+ kwargs.get("num_nodes_to_cache", 0),
+ 1,
+ fake_zmq_port, # Initial port, can be updated at runtime
+ "",
+ "",
+ )
def search(
self,
@@ -136,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
prune_ratio: float = 0.0,
recompute_embeddings: bool = False,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
- zmq_port: int = 5557,
+ zmq_port: Optional[int] = None,
batch_recompute: bool = False,
dedup_node_dis: bool = False,
**kwargs,
@@ -155,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
- "global": Use global pruning strategy (default)
- "local": Use local pruning strategy
- "proportional": Not supported in DiskANN, falls back to global
- zmq_port: ZMQ port for embedding server
+ zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
**kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
@@ -163,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
Returns:
Dict with 'labels' (list of lists) and 'distances' (ndarray)
"""
+ # Handle zmq_port compatibility: DiskANN can now update port at runtime
+ if recompute_embeddings:
+ if zmq_port is None:
+ raise ValueError(
+ "zmq_port must be provided if recompute_embeddings is True"
+ )
+ current_port = self._index.get_zmq_port()
+ if zmq_port != current_port:
+ logger.debug(
+ f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
+ )
+ self._index.set_zmq_port(zmq_port)
+
# DiskANN doesn't support "proportional" strategy
if pruning_strategy == "proportional":
raise NotImplementedError(
"DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
)
- # Use recompute_embeddings parameter
- use_recompute = recompute_embeddings
- if use_recompute:
- meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
- if not meta_file_path.exists():
- raise RuntimeError(
- f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
- )
- self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
-
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -188,21 +241,23 @@ class DiskannSearcher(BaseSearcher):
else: # "global"
use_global_pruning = True
- labels, distances = self._index.batch_search(
- query,
- query.shape[0],
- top_k,
- complexity,
- beam_width,
- self.num_threads,
- kwargs.get("USE_DEFERRED_FETCH", False),
- kwargs.get("skip_search_reorder", False),
- use_recompute,
- dedup_node_dis,
- prune_ratio,
- batch_recompute,
- use_global_pruning,
- )
+ # Perform search with suppressed C++ output based on log level
+ with suppress_cpp_output_if_needed():
+ labels, distances = self._index.batch_search(
+ query,
+ query.shape[0],
+ top_k,
+ complexity,
+ beam_width,
+ self.num_threads,
+ kwargs.get("USE_DEFERRED_FETCH", False),
+ kwargs.get("skip_search_reorder", False),
+ recompute_embeddings,
+ dedup_node_dis,
+ prune_ratio,
+ batch_recompute,
+ use_global_pruning,
+ )
string_labels = [
[str(int_label) for int_label in batch_labels] for batch_labels in labels
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
new file mode 100644
index 0000000..18bcd09
--- /dev/null
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_embedding_server.py
@@ -0,0 +1,269 @@
+"""
+DiskANN-specific embedding server
+"""
+
+import argparse
+import threading
+import time
+import os
+import zmq
+import numpy as np
+import json
+from pathlib import Path
+from typing import Optional
+import sys
+import logging
+
+# Set up logging based on environment variable
+LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+logger = logging.getLogger(__name__)
+
+# Force set logger level (don't rely on basicConfig in subprocess)
+log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
+logger.setLevel(log_level)
+
+# Ensure we have a handler if none exists
+if not logger.handlers:
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ logger.propagate = False
+
+
+def create_diskann_embedding_server(
+ passages_file: Optional[str] = None,
+ zmq_port: int = 5555,
+ model_name: str = "sentence-transformers/all-mpnet-base-v2",
+ embedding_mode: str = "sentence-transformers",
+):
+ """
+ Create and start a ZMQ-based embedding server for DiskANN backend.
+    Uses a REP socket; speaks protobuf for DiskANN C++ clients, with a msgpack fallback for BaseSearcher text requests.
+ """
+ logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
+ logger.info(f"Using embedding mode: {embedding_mode}")
+
+ # Add leann-core to path for unified embedding computation
+ current_dir = Path(__file__).parent
+ leann_core_path = current_dir.parent.parent / "leann-core" / "src"
+ sys.path.insert(0, str(leann_core_path))
+
+ try:
+ from leann.embedding_compute import compute_embeddings
+ from leann.api import PassageManager
+
+ logger.info("Successfully imported unified embedding computation module")
+ except ImportError as e:
+ logger.error(f"Failed to import embedding computation module: {e}")
+ return
+ finally:
+ sys.path.pop(0)
+
+ # Check port availability
+ import socket
+
+ def check_port(port):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return s.connect_ex(("localhost", port)) == 0
+
+ if check_port(zmq_port):
+ logger.error(f"Port {zmq_port} is already in use")
+ return
+
+    # Only metadata files (.meta.json) are supported; fail fast for anything else
+ if not passages_file or not passages_file.endswith(".meta.json"):
+ raise ValueError("Only metadata files (.meta.json) are supported")
+
+ # Load metadata to get passage sources
+ with open(passages_file, "r") as f:
+ meta = json.load(f)
+
+ passages = PassageManager(meta["passage_sources"])
+ logger.info(
+ f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
+ )
+
+ # Import protobuf after ensuring the path is correct
+ try:
+ from . import embedding_pb2
+ except ImportError as e:
+ logger.error(f"Failed to import protobuf module: {e}")
+ return
+
+ def zmq_server_thread():
+ """ZMQ server thread using REP socket for universal compatibility"""
+ context = zmq.Context()
+ socket = context.socket(zmq.REP) # REP socket for both BaseSearcher and DiskANN C++ REQ clients
+ socket.bind(f"tcp://*:{zmq_port}")
+ logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
+
+ socket.setsockopt(zmq.RCVTIMEO, 300000)
+ socket.setsockopt(zmq.SNDTIMEO, 300000)
+
+ while True:
+ try:
+ # REP socket receives single-part messages
+ message = socket.recv()
+
+ # Check for empty messages - REP socket requires response to every request
+ if len(message) == 0:
+ logger.debug("Received empty message, sending empty response")
+ socket.send(b"") # REP socket must respond to every request
+ continue
+
+ logger.debug(f"Received ZMQ request of size {len(message)} bytes")
+ logger.debug(f"Message preview: {message[:50]}") # Show first 50 bytes
+
+ e2e_start = time.time()
+
+ # Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
+ texts = []
+ node_ids = []
+ is_text_request = False
+
+ try:
+ req_proto = embedding_pb2.NodeEmbeddingRequest()
+ req_proto.ParseFromString(message)
+ node_ids = list(req_proto.node_ids)
+
+ if not node_ids:
+ raise RuntimeError(f"PROTOBUF: Received empty node_ids! Message size: {len(message)}")
+
+ logger.info(
+                        f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
+ )
+ except Exception as protobuf_error:
+ logger.debug(f"Protobuf parsing failed: {protobuf_error}")
+ # Fallback to msgpack (for BaseSearcher direct text requests)
+ try:
+ import msgpack
+
+ request = msgpack.unpackb(message)
+ # For BaseSearcher compatibility, request is a list of texts directly
+ if isinstance(request, list) and all(
+ isinstance(item, str) for item in request
+ ):
+ texts = request
+ is_text_request = True
+ logger.info(
+                            f"✅ MSGPACK: Direct text request for {len(texts)} texts"
+ )
+ else:
+ raise ValueError("Not a valid msgpack text request")
+ except Exception as msgpack_error:
+ raise RuntimeError(
+ f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
+ )
+
+ # Look up texts by node IDs (only if not direct text request)
+ if not is_text_request:
+ for nid in node_ids:
+ try:
+ passage_data = passages.get_passage(str(nid))
+ txt = passage_data["text"]
+ if not txt:
+ raise RuntimeError(
+ f"FATAL: Empty text for passage ID {nid}"
+ )
+ texts.append(txt)
+ except KeyError as e:
+ logger.error(f"Passage ID {nid} not found: {e}")
+ raise e
+ except Exception as e:
+ logger.error(f"Exception looking up passage ID {nid}: {e}")
+ raise
+
+ # Debug logging
+ logger.debug(
+ f"Processing {len(texts)} texts"
+ )
+ logger.debug(
+ f"Text lengths: {[len(t) for t in texts[:5]]}"
+ ) # Show first 5
+
+ # Process embeddings using unified computation
+ embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+ logger.info(
+ f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
+ )
+
+ # Prepare response based on request type
+ if is_text_request:
+ # For BaseSearcher compatibility: return msgpack format
+ import msgpack
+
+ response_data = msgpack.packb(embeddings.tolist())
+ else:
+ # For DiskANN C++ compatibility: return protobuf format
+ resp_proto = embedding_pb2.NodeEmbeddingResponse()
+ hidden_contiguous = np.ascontiguousarray(
+ embeddings, dtype=np.float32
+ )
+
+ # Serialize embeddings data
+ resp_proto.embeddings_data = hidden_contiguous.tobytes()
+ resp_proto.dimensions.append(hidden_contiguous.shape[0])
+ resp_proto.dimensions.append(hidden_contiguous.shape[1])
+
+ response_data = resp_proto.SerializeToString()
+
+ # Send response back to the client
+ socket.send(response_data)
+
+ e2e_end = time.time()
+                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
+
+ except zmq.Again:
+ logger.debug("ZMQ socket timeout, continuing to listen")
+ continue
+ except Exception as e:
+ logger.error(f"Error in ZMQ server loop: {e}")
+ import traceback
+ traceback.print_exc()
+ raise
+
+ zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
+ zmq_thread.start()
+ logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")
+
+ # Keep the main thread alive
+ try:
+ while True:
+ time.sleep(1)
+ except KeyboardInterrupt:
+ logger.info("DiskANN Server shutting down...")
+ return
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="DiskANN Embedding service")
+ parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
+ parser.add_argument(
+ "--passages-file",
+ type=str,
+ help="Metadata JSON file containing passage sources",
+ )
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="sentence-transformers/all-mpnet-base-v2",
+ help="Embedding model name",
+ )
+ parser.add_argument(
+ "--embedding-mode",
+ type=str,
+ default="sentence-transformers",
+ choices=["sentence-transformers", "openai", "mlx"],
+ help="Embedding backend mode",
+ )
+
+ args = parser.parse_args()
+
+ # Create and start the DiskANN embedding server
+ create_diskann_embedding_server(
+ passages_file=args.passages_file,
+ zmq_port=args.zmq_port,
+ model_name=args.model_name,
+ embedding_mode=args.embedding_mode,
+ )
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
deleted file mode 100644
index 04f7f56..0000000
--- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
+++ /dev/null
@@ -1,705 +0,0 @@
-#!/usr/bin/env python3
-"""
-Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
-"""
-
-import pickle
-import argparse
-import time
-import json
-from typing import Dict, Any, Optional, Union
-
-from transformers import AutoTokenizer, AutoModel
-import os
-from contextlib import contextmanager
-import zmq
-import numpy as np
-import msgpack
-from pathlib import Path
-import logging
-
-RED = "\033[91m"
-
-# Set up logging based on environment variable
-LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
-logging.basicConfig(
- level=getattr(logging, LOG_LEVEL, logging.INFO),
- format='%(asctime)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-RESET = "\033[0m"
-
-# --- New Passage Loader from HNSW backend ---
-class SimplePassageLoader:
- """
- Simple passage loader that replaces config.py dependencies
- """
- def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
- self.passages_data = passages_data or {}
- self._meta_path = ''
-
- def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
- """Get passage by ID"""
- str_id = str(passage_id)
- if str_id in self.passages_data:
- return {"text": self.passages_data[str_id]}
- else:
- # Return empty text for missing passages
- return {"text": ""}
-
- def __len__(self) -> int:
- return len(self.passages_data)
-
- def keys(self):
- return self.passages_data.keys()
-
-def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
- """
- Load passages using metadata file with PassageManager for lazy loading
- """
- # Load metadata to get passage sources
- with open(meta_file, 'r') as f:
- meta = json.load(f)
-
- # Import PassageManager dynamically to avoid circular imports
- import sys
- from pathlib import Path
-
- # Find the leann package directory relative to this file
- current_dir = Path(__file__).parent
- leann_core_path = current_dir.parent.parent / "leann-core" / "src"
- sys.path.insert(0, str(leann_core_path))
-
- try:
- from leann.api import PassageManager
- passage_manager = PassageManager(meta['passage_sources'])
- finally:
- sys.path.pop(0)
-
- print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")
-
- class LazyPassageLoader(SimplePassageLoader):
- def __init__(self, passage_manager):
- self.passage_manager = passage_manager
- # Initialize parent with empty data
- super().__init__({})
-
- def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
- """Get passage by ID with lazy loading"""
- try:
- int_id = int(passage_id)
- string_id = str(int_id)
- passage_data = self.passage_manager.get_passage(string_id)
- if passage_data and passage_data.get("text"):
- return {"text": passage_data["text"]}
- else:
- raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
- except Exception as e:
- raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
-
- def __len__(self) -> int:
- return len(self.passage_manager.global_offset_map)
-
- def keys(self):
- return self.passage_manager.global_offset_map.keys()
-
- loader = LazyPassageLoader(passage_manager)
- loader._meta_path = meta_file
- return loader
-
-def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
- """
- Load passages from a JSONL file with label map support
- Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
- """
-
- if not os.path.exists(passages_file):
- raise FileNotFoundError(f"Passages file {passages_file} not found.")
-
- if not passages_file.endswith('.jsonl'):
- raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
-
- # Load passages directly by their sequential IDs
- passages_data = {}
- with open(passages_file, 'r', encoding='utf-8') as f:
- for line in f:
- if line.strip():
- passage = json.loads(line)
- passages_data[passage['id']] = passage['text']
-
- print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
- return SimplePassageLoader(passages_data)
-
-def create_embedding_server_thread(
- zmq_port=5555,
- model_name="sentence-transformers/all-mpnet-base-v2",
- max_batch_size=128,
- passages_file: Optional[str] = None,
- embedding_mode: str = "sentence-transformers",
- enable_warmup: bool = False,
-):
- """
- Create and run embedding server in the current thread
- This function is designed to be called in a separate thread
- """
- logger.info(f"Initializing embedding server thread on port {zmq_port}")
-
- try:
- # Check if port is already occupied
- import socket
- def check_port(port):
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- return s.connect_ex(('localhost', port)) == 0
-
- if check_port(zmq_port):
- print(f"{RED}Port {zmq_port} is already in use{RESET}")
- return
-
- # Auto-detect mode based on model name if not explicitly set
- if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
- embedding_mode = "openai"
-
- if embedding_mode == "mlx":
- from leann.api import compute_embeddings_mlx
- import torch
- logger.info("Using MLX for embeddings")
- # Set device to CPU for compatibility with DeviceTimer class
- device = torch.device("cpu")
- cuda_available = False
- mps_available = False
- elif embedding_mode == "openai":
- from leann.api import compute_embeddings_openai
- import torch
- logger.info("Using OpenAI API for embeddings")
- # Set device to CPU for compatibility with DeviceTimer class
- device = torch.device("cpu")
- cuda_available = False
- mps_available = False
- elif embedding_mode == "sentence-transformers":
- # Initialize model
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
- import torch
-
- # Select device
- mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
- cuda_available = torch.cuda.is_available()
-
- if cuda_available:
- device = torch.device("cuda")
- logger.info("Using CUDA device")
- elif mps_available:
- device = torch.device("mps")
- logger.info("Using MPS device (Apple Silicon)")
- else:
- device = torch.device("cpu")
- logger.info("Using CPU device")
-
- # Load model
- logger.info(f"Loading model {model_name}")
- model = AutoModel.from_pretrained(model_name).to(device).eval()
-
- # Optimize model
- if cuda_available or mps_available:
- try:
- model = model.half()
- model = torch.compile(model)
- logger.info(f"Using FP16 precision with model: {model_name}")
- except Exception as e:
- print(f"WARNING: Model optimization failed: {e}")
- else:
- raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")
-
- # Load passages from file if provided
- if passages_file and os.path.exists(passages_file):
- # Check if it's a metadata file or a single passages file
- if passages_file.endswith('.meta.json'):
- passages = load_passages_from_metadata(passages_file)
- else:
- # Try to find metadata file in same directory
- passages_dir = Path(passages_file).parent
- meta_files = list(passages_dir.glob("*.meta.json"))
- if meta_files:
- print(f"Found metadata file: {meta_files[0]}, using lazy loading")
- passages = load_passages_from_metadata(str(meta_files[0]))
- else:
- # Fallback to original single file loading (will cause warnings)
- print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
- passages = load_passages_from_file(passages_file)
- else:
- print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
- passages = SimplePassageLoader()
-
- logger.info(f"Loaded {len(passages)} passages.")
-
- def client_warmup(zmq_port):
- """Perform client-side warmup for DiskANN server"""
- time.sleep(2)
- print(f"Performing client-side warmup with model {model_name}...")
-
- # Get actual passage IDs from the loaded passages
- sample_ids = []
- if hasattr(passages, 'keys') and len(passages) > 0:
- available_ids = list(passages.keys())
- # Take up to 5 actual IDs, but at least 1
- sample_ids = available_ids[:min(5, len(available_ids))]
- print(f"Using actual passage IDs for warmup: {sample_ids}")
- else:
- print("No passages available for warmup, skipping warmup...")
- return
-
- try:
- context = zmq.Context()
- socket = context.socket(zmq.REQ)
- socket.connect(f"tcp://localhost:{zmq_port}")
- socket.setsockopt(zmq.RCVTIMEO, 30000)
- socket.setsockopt(zmq.SNDTIMEO, 30000)
-
- try:
- ids_to_send = [int(x) for x in sample_ids]
- except ValueError:
- print("Warning: Could not convert sample IDs to integers, skipping warmup")
- return
-
- if not ids_to_send:
- print("Skipping warmup send.")
- return
-
- # Use protobuf format for warmup
- from . import embedding_pb2
- req_proto = embedding_pb2.NodeEmbeddingRequest()
- req_proto.node_ids.extend(ids_to_send)
- request_bytes = req_proto.SerializeToString()
-
- for i in range(3):
- print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
- socket.send(request_bytes)
- response_bytes = socket.recv()
-
- resp_proto = embedding_pb2.NodeEmbeddingResponse()
- resp_proto.ParseFromString(response_bytes)
- embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
- print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
- time.sleep(0.1)
-
- print("Client-side Protobuf ZMQ warmup complete")
- socket.close()
- context.term()
- except Exception as e:
- print(f"Error during Protobuf ZMQ warmup: {e}")
-
- class DeviceTimer:
- """Device timer"""
- def __init__(self, name="", device=device):
- self.name = name
- self.device = device
- self.start_time = 0
- self.end_time = 0
-
- if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
- self.start_event = torch.cuda.Event(enable_timing=True)
- self.end_event = torch.cuda.Event(enable_timing=True)
- else:
- self.start_event = None
- self.end_event = None
-
- @contextmanager
- def timing(self):
- self.start()
- yield
- self.end()
-
- def start(self):
- if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
- torch.cuda.synchronize()
- self.start_event.record()
- else:
- if embedding_mode == "sentence-transformers" and self.device.type == "mps":
- torch.mps.synchronize()
- self.start_time = time.time()
-
- def end(self):
- if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
- self.end_event.record()
- torch.cuda.synchronize()
- else:
- if embedding_mode == "sentence-transformers" and self.device.type == "mps":
- torch.mps.synchronize()
- self.end_time = time.time()
-
- def elapsed_time(self):
- if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
- return self.start_event.elapsed_time(self.end_event) / 1000.0
- else:
- return self.end_time - self.start_time
-
- def print_elapsed(self):
- elapsed = self.elapsed_time()
- print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")
-
- def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
- """Process text batch"""
- if not texts_batch:
- return np.array([])
-
- # Filter out empty texts and their corresponding IDs
- valid_texts = []
- valid_ids = []
- for i, text in enumerate(texts_batch):
- if text.strip(): # Only include non-empty texts
- valid_texts.append(text)
- valid_ids.append(ids_batch[i])
-
- if not valid_texts:
- print("WARNING: No valid texts in batch")
- return np.array([])
-
- # Tokenize
- token_timer = DeviceTimer("tokenization")
- with token_timer.timing():
- inputs = tokenizer(
- valid_texts,
- padding=True,
- truncation=True,
- max_length=512,
- return_tensors="pt"
- ).to(device)
-
- # Compute embeddings
- embed_timer = DeviceTimer("embedding computation")
- with embed_timer.timing():
- with torch.no_grad():
- outputs = model(**inputs)
- hidden_states = outputs.last_hidden_state
-
- # Mean pooling
- attention_mask = inputs['attention_mask']
- mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
- sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
- sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
- batch_embeddings = sum_embeddings / sum_mask
- embed_timer.print_elapsed()
-
- return batch_embeddings.cpu().numpy()
-
- # ZMQ server main loop - modified to use REP socket
- context = zmq.Context()
- socket = context.socket(zmq.ROUTER) # Changed to REP socket
- socket.bind(f"tcp://127.0.0.1:{zmq_port}")
- print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")
-
- # Set timeouts
- socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second receive timeout
- socket.setsockopt(zmq.SNDTIMEO, 300000) # 300 second send timeout
-
- from . import embedding_pb2
-
- print(f"INFO: Embedding server ready to serve requests")
-
- # Start warmup thread if enabled
- if enable_warmup and len(passages) > 0:
- import threading
- print(f"Warmup enabled: starting warmup thread")
- warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
- warmup_thread.daemon = True
- warmup_thread.start()
- else:
- print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")
-
- while True:
- try:
- parts = socket.recv_multipart()
-
- # --- Restore robust message format detection ---
- # Must check parts length to avoid IndexError
- if len(parts) >= 3:
- identity = parts[0]
- # empty = parts[1] # We usually don't care about the middle empty frame
- message = parts[2]
- elif len(parts) == 2:
- # Can also handle cases without empty frame
- identity = parts[0]
- message = parts[1]
- else:
- # If received message format is wrong, print warning and ignore it instead of crashing
- print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
- continue
- print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")
-
- # Handle control messages (MessagePack format)
- try:
- request_payload = msgpack.unpackb(message)
- if isinstance(request_payload, list) and len(request_payload) >= 1:
- if request_payload[0] == "__QUERY_META_PATH__":
- # Return the current meta path being used by the server
- current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
- response = [current_meta_path]
- socket.send_multipart([identity, b'', msgpack.packb(response)])
- continue
-
- elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
- # Update the server's meta path and reload passages
- new_meta_path = request_payload[1]
- try:
- print(f"INFO: Updating server meta path to: {new_meta_path}")
- # Reload passages from the new meta file
- passages = load_passages_from_metadata(new_meta_path)
- # Store the meta path for future queries
- passages._meta_path = new_meta_path
- response = ["SUCCESS"]
- print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
- except Exception as e:
- print(f"ERROR: Failed to update meta path: {e}")
- response = ["FAILED", str(e)]
- socket.send_multipart([identity, b'', msgpack.packb(response)])
- continue
-
- elif request_payload[0] == "__QUERY_MODEL__":
- # Return the current model being used by the server
- response = [model_name]
- socket.send_multipart([identity, b'', msgpack.packb(response)])
- continue
-
- elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
- # Update the server's embedding model
- new_model_name = request_payload[1]
- try:
- print(f"INFO: Updating server model from {model_name} to: {new_model_name}")
-
- # Clean up old model to free memory
- if not use_mlx:
- print("INFO: Releasing old model from memory...")
- old_model = model
- old_tokenizer = tokenizer
-
- # Load new tokenizer first
- print(f"Loading new tokenizer for {new_model_name}...")
- tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)
-
- # Load new model
- print(f"Loading new model {new_model_name}...")
- model = AutoModel.from_pretrained(new_model_name).to(device).eval()
-
- # Optimize new model
- if cuda_available or mps_available:
- try:
- model = model.half()
- model = torch.compile(model)
- print(f"INFO: Using FP16 precision with model: {new_model_name}")
- except Exception as e:
- print(f"WARNING: Model optimization failed: {e}")
-
- # Now safely delete old model after new one is loaded
- del old_model
- del old_tokenizer
-
- # Clear GPU cache if available
- if device.type == "cuda":
- torch.cuda.empty_cache()
- print("INFO: Cleared CUDA cache")
- elif device.type == "mps":
- torch.mps.empty_cache()
- print("INFO: Cleared MPS cache")
-
- # Force garbage collection
- import gc
- gc.collect()
- print("INFO: Memory cleanup completed")
-
- # Update model name
- model_name = new_model_name
-
- response = ["SUCCESS"]
- print(f"INFO: Successfully updated model to: {new_model_name}")
- except Exception as e:
- print(f"ERROR: Failed to update model: {e}")
- response = ["FAILED", str(e)]
- socket.send_multipart([identity, b'', msgpack.packb(response)])
- continue
- except:
- # Not a control message, continue with normal protobuf processing
- pass
-
- e2e_start = time.time()
- lookup_timer = DeviceTimer("text lookup")
-
- # Parse request
- req_proto = embedding_pb2.NodeEmbeddingRequest()
- req_proto.ParseFromString(message)
- node_ids = req_proto.node_ids
- print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")
-
- # Add debug information
- if len(node_ids) > 0:
- print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")
-
- # Look up texts
- texts = []
- missing_ids = []
- with lookup_timer.timing():
- for nid in node_ids:
- txtinfo = passages[nid]
- txt = txtinfo["text"]
- if txt:
- texts.append(txt)
- else:
- # If text is empty, we still need a placeholder for batch processing,
- # but record its ID as missing
- texts.append("")
- missing_ids.append(nid)
- lookup_timer.print_elapsed()
-
- if missing_ids:
- print(f"WARNING: Missing passages for IDs: {missing_ids}")
-
- # Process batch
- total_size = len(texts)
- print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")
-
- all_embeddings = []
-
- if total_size > max_batch_size:
- print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
- for i in range(0, total_size, max_batch_size):
- end_idx = min(i + max_batch_size, total_size)
- print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
-
- chunk_texts = texts[i:end_idx]
- chunk_ids = node_ids[i:end_idx]
-
- if embedding_mode == "mlx":
- embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
- elif embedding_mode == "openai":
- embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
- else: # sentence-transformers
- embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
- all_embeddings.append(embeddings_chunk)
-
- if embedding_mode == "sentence-transformers":
- if cuda_available:
- torch.cuda.empty_cache()
- elif device.type == "mps":
- torch.mps.empty_cache()
-
- hidden = np.vstack(all_embeddings)
- print(f"INFO: Combined embeddings shape: {hidden.shape}")
- else:
- if embedding_mode == "mlx":
- hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
- elif embedding_mode == "openai":
- hidden = compute_embeddings_openai(texts, model_name)
- else: # sentence-transformers
- hidden = process_batch_pytorch(texts, node_ids, missing_ids)
-
- # Serialize response
- ser_start = time.time()
-
- resp_proto = embedding_pb2.NodeEmbeddingResponse()
- hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
- resp_proto.embeddings_data = hidden_contiguous.tobytes()
- resp_proto.dimensions.append(hidden_contiguous.shape[0])
- resp_proto.dimensions.append(hidden_contiguous.shape[1])
- resp_proto.missing_ids.extend(missing_ids)
-
- response_data = resp_proto.SerializeToString()
-
- # REP socket sends a single response
- socket.send_multipart([identity, b'', response_data])
-
- ser_end = time.time()
-
- print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")
-
- if embedding_mode == "sentence-transformers":
- if device.type == "cuda":
- torch.cuda.synchronize()
- elif device.type == "mps":
- torch.mps.synchronize()
- e2e_end = time.time()
- print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")
-
- except zmq.Again:
- print("INFO: ZMQ socket timeout, continuing to listen")
- continue
- except Exception as e:
- print(f"ERROR: Error in ZMQ server: {e}")
- try:
- # Send empty response to maintain REQ-REP state
- empty_resp = embedding_pb2.NodeEmbeddingResponse()
- socket.send(empty_resp.SerializeToString())
- except:
- # If sending fails, recreate socket
- socket.close()
- socket = context.socket(zmq.REP)
- socket.bind(f"tcp://127.0.0.1:{zmq_port}")
- socket.setsockopt(zmq.RCVTIMEO, 5000)
- socket.setsockopt(zmq.SNDTIMEO, 300000)
- print("INFO: ZMQ socket recreated after error")
-
- except Exception as e:
- print(f"ERROR: Failed to start embedding server: {e}")
- raise
-
-
-def create_embedding_server(
- domain="demo",
- load_passages=True,
- load_embeddings=False,
- use_fp16=True,
- use_int8=False,
- use_cuda_graphs=False,
- zmq_port=5555,
- max_batch_size=128,
- lazy_load_passages=False,
- model_name="sentence-transformers/all-mpnet-base-v2",
- passages_file: Optional[str] = None,
- embedding_mode: str = "sentence-transformers",
- enable_warmup: bool = False,
-):
- """
-    Keeps the original create_embedding_server interface unchanged.
-    This is the blocking version, intended for direct execution.
- """
- create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Embedding service")
- parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
- parser.add_argument("--domain", type=str, default="demo", help="Domain name")
- parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
- parser.add_argument("--load-passages", action="store_true", default=True)
- parser.add_argument("--load-embeddings", action="store_true", default=False)
- parser.add_argument("--use-fp16", action="store_true", default=False)
- parser.add_argument("--use-int8", action="store_true", default=False)
- parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
- parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
- parser.add_argument("--lazy-load-passages", action="store_true", default=True)
- parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
- help="Embedding model name")
- parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
- choices=["sentence-transformers", "mlx", "openai"],
- help="Embedding backend mode")
- parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
- parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
- args = parser.parse_args()
-
- # Handle backward compatibility with use_mlx
- embedding_mode = args.embedding_mode
- if args.use_mlx:
- embedding_mode = "mlx"
-
- create_embedding_server(
- domain=args.domain,
- load_passages=args.load_passages,
- load_embeddings=args.load_embeddings,
- use_fp16=args.use_fp16,
- use_int8=args.use_int8,
- use_cuda_graphs=args.use_cuda_graphs,
- zmq_port=args.zmq_port,
- max_batch_size=args.max_batch_size,
- lazy_load_passages=args.lazy_load_passages,
- model_name=args.model_name,
- passages_file=args.passages_file,
- embedding_mode=embedding_mode,
- enable_warmup=not args.disable_warmup,
- )
diff --git a/packages/leann-backend-diskann/third_party/DiskANN b/packages/leann-backend-diskann/third_party/DiskANN
index af2a264..25339b0 160000
--- a/packages/leann-backend-diskann/third_party/DiskANN
+++ b/packages/leann-backend-diskann/third_party/DiskANN
@@ -1 +1 @@
-Subproject commit af2a26481e65232b57b82d96e68833cdee9f7635
+Subproject commit 25339b03413b5067c25b6092ea3e0f77ef8515c8
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
index f1f8da0..0aa903e 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
@@ -142,12 +142,12 @@ class HNSWSearcher(BaseSearcher):
self,
query: np.ndarray,
top_k: int,
+ zmq_port: Optional[int] = None,
complexity: int = 64,
beam_width: int = 1,
prune_ratio: float = 0.0,
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
- expected_zmq_port: Optional[int] = None,
batch_size: int = 0,
**kwargs,
) -> Dict[str, Any]:
@@ -165,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
- "global": Use global PQ queue size for selection (default)
- "local": Local pruning, sort and select best candidates
- "proportional": Base selection on new neighbor count ratio
- expected_zmq_port: ZMQ port for embedding server
+ zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
**kwargs: Additional HNSW-specific parameters (for legacy compatibility)
@@ -177,6 +177,11 @@ class HNSWSearcher(BaseSearcher):
if not recompute_embeddings:
if self.is_pruned:
raise RuntimeError("Recompute is required for pruned index.")
+ if recompute_embeddings:
+ if zmq_port is None:
+ raise ValueError(
+ "zmq_port must be provided if recompute_embeddings is True"
+ )
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -184,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
faiss.normalize_L2(query)
params = faiss.SearchParametersHNSW()
- params.zmq_port = expected_zmq_port
+ if zmq_port is not None:
+ params.zmq_port = (
+ zmq_port # C++ code won't use this if recompute_embeddings is False
+ )
params.efSearch = complexity
params.beam_size = beam_width
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index cbee3f7..111a52b 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -450,7 +450,7 @@ class LeannSearcher:
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
- expected_zmq_port=zmq_port,
+ zmq_port=zmq_port,
**kwargs,
)
search_time = time.time() - start_time
diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py
index 854265b..0b830b8 100644
--- a/packages/leann-core/src/leann/cli.py
+++ b/packages/leann-core/src/leann/cli.py
@@ -1,10 +1,6 @@
-#!/usr/bin/env python3
import argparse
import asyncio
-import sys
from pathlib import Path
-from typing import Optional
-import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
@@ -16,20 +12,20 @@ class LeannCLI:
def __init__(self):
self.indexes_dir = Path.home() / ".leann" / "indexes"
self.indexes_dir.mkdir(parents=True, exist_ok=True)
-
+
self.node_parser = SentenceSplitter(
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
)
-
+
def get_index_path(self, index_name: str) -> str:
index_dir = self.indexes_dir / index_name
return str(index_dir / "documents.leann")
-
+
def index_exists(self, index_name: str) -> bool:
index_dir = self.indexes_dir / index_name
meta_file = index_dir / "documents.leann.meta.json"
return meta_file.exists()
-
+
def create_parser(self) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="leann",
@@ -41,24 +37,32 @@ Examples:
leann search my-docs "query" # Search in my-docs index
leann ask my-docs "question" # Ask my-docs index
leann list # List all stored indexes
- """
+ """,
)
-
+
subparsers = parser.add_subparsers(dest="command", help="Available commands")
-
+
# Build command
build_parser = subparsers.add_parser("build", help="Build document index")
build_parser.add_argument("index_name", help="Index name")
- build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
- build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
- build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
- build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
+ build_parser.add_argument(
+ "--docs", type=str, required=True, help="Documents directory"
+ )
+ build_parser.add_argument(
+ "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
+ )
+ build_parser.add_argument(
+ "--embedding-model", type=str, default="facebook/contriever"
+ )
+ build_parser.add_argument(
+ "--force", "-f", action="store_true", help="Force rebuild"
+ )
build_parser.add_argument("--graph-degree", type=int, default=32)
build_parser.add_argument("--complexity", type=int, default=64)
build_parser.add_argument("--num-threads", type=int, default=1)
build_parser.add_argument("--compact", action="store_true", default=True)
build_parser.add_argument("--recompute", action="store_true", default=True)
-
+
# Search command
search_parser = subparsers.add_parser("search", help="Search documents")
search_parser.add_argument("index_name", help="Index name")
@@ -68,12 +72,21 @@ Examples:
search_parser.add_argument("--beam-width", type=int, default=1)
search_parser.add_argument("--prune-ratio", type=float, default=0.0)
search_parser.add_argument("--recompute-embeddings", action="store_true")
- search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
-
+ search_parser.add_argument(
+ "--pruning-strategy",
+ choices=["global", "local", "proportional"],
+ default="global",
+ )
+
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
- ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
+ ask_parser.add_argument(
+ "--llm",
+ type=str,
+ default="ollama",
+ choices=["simulated", "ollama", "hf", "openai"],
+ )
ask_parser.add_argument("--model", type=str, default="qwen3:8b")
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument("--interactive", "-i", action="store_true")
@@ -82,81 +95,91 @@ Examples:
ask_parser.add_argument("--beam-width", type=int, default=1)
ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
ask_parser.add_argument("--recompute-embeddings", action="store_true")
- ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
-
+ ask_parser.add_argument(
+ "--pruning-strategy",
+ choices=["global", "local", "proportional"],
+ default="global",
+ )
+
# List command
list_parser = subparsers.add_parser("list", help="List all indexes")
-
+
return parser
-
+
def list_indexes(self):
print("Stored LEANN indexes:")
-
+
if not self.indexes_dir.exists():
-            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
+ print(
+                "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
+ )
return
-
+
index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]
-
+
if not index_dirs:
-            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
+ print(
+                "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
+ )
return
-
+
print(f"Found {len(index_dirs)} indexes:")
for i, index_dir in enumerate(index_dirs, 1):
index_name = index_dir.name
            status = "✓" if self.index_exists(index_name) else "✗"
-
+
print(f" {i}. {index_name} [{status}]")
if self.index_exists(index_name):
meta_file = index_dir / "documents.leann.meta.json"
- size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
+ size_mb = sum(
+ f.stat().st_size for f in index_dir.iterdir() if f.is_file()
+ ) / (1024 * 1024)
print(f" Size: {size_mb:.1f} MB")
-
+
if index_dirs:
example_name = index_dirs[0].name
print(f"\nUsage:")
- print(f" leann search {example_name} \"your query\"")
+ print(f' leann search {example_name} "your query"')
print(f" leann ask {example_name} --interactive")
-
+
def load_documents(self, docs_dir: str):
print(f"Loading documents from {docs_dir}...")
-
+
documents = SimpleDirectoryReader(
docs_dir,
recursive=True,
encoding="utf-8",
required_exts=[".pdf", ".txt", ".md", ".docx"],
).load_data(show_progress=True)
-
+
all_texts = []
for doc in documents:
nodes = self.node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.get_content())
-
+
print(f"Loaded {len(documents)} documents, {len(all_texts)} chunks")
return all_texts
-
+
async def build_index(self, args):
docs_dir = args.docs
index_name = args.index_name
index_dir = self.indexes_dir / index_name
index_path = self.get_index_path(index_name)
-
+
if index_dir.exists() and not args.force:
print(f"Index '{index_name}' already exists. Use --force to rebuild.")
return
-
+
all_texts = self.load_documents(docs_dir)
if not all_texts:
print("No documents found")
return
-
+
index_dir.mkdir(parents=True, exist_ok=True)
-
+
print(f"Building index '{index_name}' with {args.backend} backend...")
-
+
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
@@ -166,103 +189,107 @@ Examples:
is_recompute=args.recompute,
num_threads=args.num_threads,
)
-
+
for chunk_text in all_texts:
builder.add_text(chunk_text)
-
+
builder.build_index(index_path)
print(f"Index built at {index_path}")
-
+
async def search_documents(self, args):
index_name = args.index_name
query = args.query
index_path = self.get_index_path(index_name)
-
+
if not self.index_exists(index_name):
-            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
+ print(
+                f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
+ )
return
-
+
searcher = LeannSearcher(index_path=index_path)
results = searcher.search(
- query,
+ query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
- pruning_strategy=args.pruning_strategy
+ pruning_strategy=args.pruning_strategy,
)
-
+
print(f"Search results for '{query}' (top {len(results)}):")
for i, result in enumerate(results, 1):
print(f"{i}. Score: {result.score:.3f}")
print(f" {result.text[:200]}...")
print()
-
+
async def ask_questions(self, args):
index_name = args.index_name
index_path = self.get_index_path(index_name)
-
+
if not self.index_exists(index_name):
-            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
+ print(
+                f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
+ )
return
-
+
print(f"Starting chat with index '{index_name}'...")
print(f"Using {args.model} ({args.llm})")
-
+
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
-
+
chat = LeannChat(index_path=index_path, llm_config=llm_config)
-
+
if args.interactive:
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
-
+
while True:
user_input = input("\nYou: ").strip()
- if user_input.lower() in ['quit', 'exit', 'q']:
+ if user_input.lower() in ["quit", "exit", "q"]:
print("Goodbye!")
break
-
+
if not user_input:
continue
-
+
response = chat.ask(
- user_input,
+ user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
- pruning_strategy=args.pruning_strategy
+ pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
else:
query = input("Enter your question: ").strip()
if query:
response = chat.ask(
- query,
+ query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
- pruning_strategy=args.pruning_strategy
+ pruning_strategy=args.pruning_strategy,
)
print(f"LEANN: {response}")
-
+
async def run(self, args=None):
parser = self.create_parser()
-
+
if args is None:
args = parser.parse_args()
-
+
if not args.command:
parser.print_help()
return
-
+
if args.command == "list":
self.list_indexes()
elif args.command == "build":
@@ -277,11 +304,12 @@ Examples:
def main():
import dotenv
+
dotenv.load_dotenv()
-
+
cli = LeannCLI()
asyncio.run(cli.run())
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py
index ee2a59d..9f6be79 100644
--- a/packages/leann-core/src/leann/embedding_compute.py
+++ b/packages/leann-core/src/leann/embedding_compute.py
@@ -60,6 +60,9 @@ def compute_embeddings_sentence_transformers(
"""
Compute embeddings using SentenceTransformer with model caching
"""
+ # Handle empty input
+ if not texts:
+ raise ValueError("Cannot compute embeddings for empty text list")
logger.info(
f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
)
diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py
index 338c3dc..93a6ce8 100644
--- a/packages/leann-core/src/leann/interface.py
+++ b/packages/leann-core/src/leann/interface.py
@@ -64,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
- zmq_port: ZMQ port for embedding server communication
+ zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
**kwargs: Backend-specific parameters
Returns:
diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py
index 8592ed4..6bd6ec8 100644
--- a/packages/leann-core/src/leann/searcher_base.py
+++ b/packages/leann-core/src/leann/searcher_base.py
@@ -104,6 +104,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
# Try to use embedding server if available and requested
if use_server_if_available:
try:
+ # TODO: Maybe we can directly use this port here?
+ # For this internal method, it's ok to assume that the server is running
+ # on that port?
+
# Ensure we have a server with passages_file for compatibility
passages_source_file = (
self.index_dir / f"{self.index_path.name}.meta.json"
@@ -181,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
- zmq_port: ZMQ port for embedding server communication
+ zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
**kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)
Returns: