Merge branch 'main' of https://github.com/yichuan-w/LEANN

README.md (65 changed lines)
@@ -292,6 +292,71 @@ Once the index is built, you can ask questions like:
</details>

## 🖥️ Command Line Interface

LEANN includes a powerful CLI for document processing and search. Perfect for quick document indexing and interactive chat.

```bash
# Build an index from documents
leann build my-docs --docs ./documents

# Search your documents
leann search my-docs "machine learning concepts"

# Interactive chat with your documents
leann ask my-docs --interactive

# List all your indexes
leann list
```

**Key CLI features:**
- Auto-detects document formats (PDF, TXT, MD, DOCX)
- Smart text chunking with overlap
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
- Organized index storage in `~/.leann/indexes/`
- Support for advanced search parameters

<details>
<summary><strong>📋 Click to expand: Complete CLI Reference</strong></summary>

**Build Command:**

```bash
leann build INDEX_NAME --docs DIRECTORY [OPTIONS]

Options:
  --backend {hnsw,diskann}   Backend to use (default: hnsw)
  --embedding-model MODEL    Embedding model (default: facebook/contriever)
  --graph-degree N           Graph degree (default: 32)
  --complexity N             Build complexity (default: 64)
  --force                    Force rebuild existing index
  --compact                  Use compact storage (default: true)
  --recompute                Enable recomputation (default: true)
```

**Search Command:**

```bash
leann search INDEX_NAME QUERY [OPTIONS]

Options:
  --top-k N                  Number of results (default: 5)
  --complexity N             Search complexity (default: 64)
  --recompute-embeddings     Use recomputation for highest accuracy
  --pruning-strategy {global,local,proportional}
```

**Ask Command:**

```bash
leann ask INDEX_NAME [OPTIONS]

Options:
  --llm {ollama,openai,hf}   LLM provider (default: ollama)
  --model MODEL              Model name (default: qwen3:8b)
  --interactive              Interactive chat mode
  --top-k N                  Retrieval count (default: 20)
```

</details>
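The advanced search parameters listed in the reference above can be combined in a single call. A hypothetical example, where the index name and query string are placeholders:

```bash
leann search my-docs "vector database pruning" \
  --top-k 10 \
  --complexity 128 \
  --recompute-embeddings \
  --pruning-strategy local
```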

## 🏗️ Architecture & How It Works

@@ -1,10 +1,13 @@
 import numpy as np
 import os
 import struct
+import sys
 from pathlib import Path
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional
 import contextlib
 
+import logging
+
 from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
 from leann.interface import (
@@ -13,6 +16,46 @@ from leann.interface import (
     LeannBackendSearcherInterface,
 )
 
+logger = logging.getLogger(__name__)
+
+
+@contextlib.contextmanager
+def suppress_cpp_output_if_needed():
+    """Suppress C++ stdout/stderr based on LEANN_LOG_LEVEL"""
+    log_level = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
+
+    # Only suppress if log level is WARNING or higher (ERROR, CRITICAL)
+    should_suppress = log_level in ["WARNING", "ERROR", "CRITICAL"]
+
+    if not should_suppress:
+        # Don't suppress, just yield
+        yield
+        return
+
+    # Save original file descriptors
+    stdout_fd = sys.stdout.fileno()
+    stderr_fd = sys.stderr.fileno()
+
+    # Save original stdout/stderr
+    stdout_dup = os.dup(stdout_fd)
+    stderr_dup = os.dup(stderr_fd)
+
+    try:
+        # Redirect to /dev/null
+        devnull = os.open(os.devnull, os.O_WRONLY)
+        os.dup2(devnull, stdout_fd)
+        os.dup2(devnull, stderr_fd)
+        os.close(devnull)
+
+        yield
+
+    finally:
+        # Restore original file descriptors
+        os.dup2(stdout_dup, stdout_fd)
+        os.dup2(stderr_dup, stderr_fd)
+        os.close(stdout_dup)
+        os.close(stderr_dup)
+
+
 def _get_diskann_metrics():
     from . import _diskannpy as diskannpy  # type: ignore
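A minimal usage sketch for the context manager added above; `noisy_native_call` is a placeholder for any `_diskannpy` call that prints from C++. The manager reads `LEANN_LOG_LEVEL` at call time, so the default (`WARNING`) silences native output while `INFO`/`DEBUG` leave it visible:

```python
import os

os.environ["LEANN_LOG_LEVEL"] = "DEBUG"  # keep C++ chatter while troubleshooting

with suppress_cpp_output_if_needed():
    noisy_native_call()  # placeholder for index construction or batch_search
```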
@@ -64,6 +107,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
+            logger.warning(f"Converting data to float32, shape: {data.shape}")
             data = data.astype(np.float32)
 
         data_filename = f"{index_prefix}_data.bin"
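For context, the cast that triggers the new warning is an ordinary NumPy conversion; a tiny illustration with made-up data:

```python
import numpy as np

data = np.random.rand(1000, 768)   # float64 by default
data = data.astype(np.float32)     # what the builder does before writing the data file
assert data.dtype == np.float32
```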
@@ -74,7 +118,9 @@ class DiskannBuilder(LeannBackendBuilderInterface):
             build_kwargs.get("distance_metric", "mips").lower()
         )
         if metric_enum is None:
-            raise ValueError("Unsupported distance_metric.")
+            raise ValueError(
+                f"Unsupported distance_metric '{build_kwargs.get('distance_metric', 'unknown')}'."
+            )
 
         try:
             from . import _diskannpy as diskannpy  # type: ignore
@@ -96,36 +142,40 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         temp_data_file = index_dir / data_filename
         if temp_data_file.exists():
             os.remove(temp_data_file)
+            logger.debug(f"Cleaned up temporary data file: {temp_data_file}")
 
 
 class DiskannSearcher(BaseSearcher):
     def __init__(self, index_path: str, **kwargs):
         super().__init__(
             index_path,
-            backend_module_name="leann_backend_diskann.embedding_server",
+            backend_module_name="leann_backend_diskann.diskann_embedding_server",
             **kwargs,
         )
-        from . import _diskannpy as diskannpy  # type: ignore
 
-        distance_metric = kwargs.get("distance_metric", "mips").lower()
-        metric_enum = _get_diskann_metrics().get(distance_metric)
-        if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
+        # Initialize DiskANN index with suppressed C++ output based on log level
+        with suppress_cpp_output_if_needed():
+            from . import _diskannpy as diskannpy  # type: ignore
 
-        self.num_threads = kwargs.get("num_threads", 8)
-        self.zmq_port = kwargs.get("zmq_port", 6666)
+            distance_metric = kwargs.get("distance_metric", "mips").lower()
+            metric_enum = _get_diskann_metrics().get(distance_metric)
+            if metric_enum is None:
+                raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
 
-        full_index_prefix = str(self.index_dir / self.index_path.stem)
-        self._index = diskannpy.StaticDiskFloatIndex(
-            metric_enum,
-            full_index_prefix,
-            self.num_threads,
-            kwargs.get("num_nodes_to_cache", 0),
-            1,
-            self.zmq_port,
-            "",
-            "",
-        )
+            self.num_threads = kwargs.get("num_threads", 8)
+
+            fake_zmq_port = 6666
+            full_index_prefix = str(self.index_dir / self.index_path.stem)
+            self._index = diskannpy.StaticDiskFloatIndex(
+                metric_enum,
+                full_index_prefix,
+                self.num_threads,
+                kwargs.get("num_nodes_to_cache", 0),
+                1,
+                fake_zmq_port,  # Initial port, can be updated at runtime
+                "",
+                "",
+            )
 
     def search(
         self,
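The constructor above deliberately binds the index to a placeholder port (6666) and defers the real choice to query time; the hunk below relies on the index handle exposing `get_zmq_port()` and `set_zmq_port()`. A rough sketch, where `searcher` stands for an already-constructed `DiskannSearcher`:

```python
# Sketch only: the port is mutable after construction, so the embedding
# server's actual port can be injected when a search is issued.
print(searcher._index.get_zmq_port())   # 6666 right after __init__
searcher._index.set_zmq_port(5557)      # e.g. the port the embedding server really uses
```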
@@ -136,7 +186,7 @@ class DiskannSearcher(BaseSearcher):
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = False,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        zmq_port: int = 5557,
+        zmq_port: Optional[int] = None,
         batch_recompute: bool = False,
         dedup_node_dis: bool = False,
         **kwargs,
@@ -155,7 +205,7 @@ class DiskannSearcher(BaseSearcher):
                 - "global": Use global pruning strategy (default)
                 - "local": Use local pruning strategy
                 - "proportional": Not supported in DiskANN, falls back to global
-            zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             batch_recompute: Whether to batch neighbor recomputation (DiskANN-specific)
             dedup_node_dis: Whether to cache and reuse distance computations (DiskANN-specific)
             **kwargs: Additional DiskANN-specific parameters (for legacy compatibility)
@@ -163,22 +213,25 @@ class DiskannSearcher(BaseSearcher):
         Returns:
             Dict with 'labels' (list of lists) and 'distances' (ndarray)
         """
+        # Handle zmq_port compatibility: DiskANN can now update port at runtime
+        if recompute_embeddings:
+            if zmq_port is None:
+                raise ValueError(
+                    "zmq_port must be provided if recompute_embeddings is True"
+                )
+            current_port = self._index.get_zmq_port()
+            if zmq_port != current_port:
+                logger.debug(
+                    f"Updating DiskANN zmq_port from {current_port} to {zmq_port}"
+                )
+                self._index.set_zmq_port(zmq_port)
+
         # DiskANN doesn't support "proportional" strategy
         if pruning_strategy == "proportional":
             raise NotImplementedError(
                 "DiskANN backend does not support 'proportional' pruning strategy. Use 'global' or 'local' instead."
             )
 
-        # Use recompute_embeddings parameter
-        use_recompute = recompute_embeddings
-        if use_recompute:
-            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
-            if not meta_file_path.exists():
-                raise RuntimeError(
-                    f"FATAL: Recompute enabled but metadata file not found: {meta_file_path}"
-                )
-            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
-
         if query.dtype != np.float32:
             query = query.astype(np.float32)
 
@@ -188,21 +241,23 @@ class DiskannSearcher(BaseSearcher):
         else:  # "global"
             use_global_pruning = True
 
-        labels, distances = self._index.batch_search(
-            query,
-            query.shape[0],
-            top_k,
-            complexity,
-            beam_width,
-            self.num_threads,
-            kwargs.get("USE_DEFERRED_FETCH", False),
-            kwargs.get("skip_search_reorder", False),
-            use_recompute,
-            dedup_node_dis,
-            prune_ratio,
-            batch_recompute,
-            use_global_pruning,
-        )
+        # Perform search with suppressed C++ output based on log level
+        with suppress_cpp_output_if_needed():
+            labels, distances = self._index.batch_search(
+                query,
+                query.shape[0],
+                top_k,
+                complexity,
+                beam_width,
+                self.num_threads,
+                kwargs.get("USE_DEFERRED_FETCH", False),
+                kwargs.get("skip_search_reorder", False),
+                recompute_embeddings,
+                dedup_node_dis,
+                prune_ratio,
+                batch_recompute,
+                use_global_pruning,
+            )
 
         string_labels = [
             [str(int_label) for int_label in batch_labels] for batch_labels in labels
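To make the revised contract concrete, here is a hypothetical call against the updated signature; the query vector, port, and `searcher` object are placeholders, and an embedding server is assumed to already be listening on that port:

```python
import numpy as np

query = np.random.rand(1, 768).astype(np.float32)  # placeholder query embedding

results = searcher.search(
    query,
    top_k=10,
    complexity=64,
    recompute_embeddings=True,
    zmq_port=5557,               # must not be None when recompute_embeddings=True
    pruning_strategy="global",
)
print(results["labels"][0], results["distances"][0])
```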
@@ -0,0 +1,269 @@
"""
DiskANN-specific embedding server
"""

import argparse
import threading
import time
import os
import zmq
import numpy as np
import json
from pathlib import Path
from typing import Optional
import sys
import logging

# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logger = logging.getLogger(__name__)

# Force set logger level (don't rely on basicConfig in subprocess)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)

# Ensure we have a handler if none exists
if not logger.handlers:
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False


def create_diskann_embedding_server(
    passages_file: Optional[str] = None,
    zmq_port: int = 5555,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    embedding_mode: str = "sentence-transformers",
):
    """
    Create and start a ZMQ-based embedding server for DiskANN backend.
    Uses ROUTER socket and protobuf communication as required by DiskANN C++ implementation.
    """
    logger.info(f"Starting DiskANN server on port {zmq_port} with model {model_name}")
    logger.info(f"Using embedding mode: {embedding_mode}")

    # Add leann-core to path for unified embedding computation
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))

    try:
        from leann.embedding_compute import compute_embeddings
        from leann.api import PassageManager

        logger.info("Successfully imported unified embedding computation module")
    except ImportError as e:
        logger.error(f"Failed to import embedding computation module: {e}")
        return
    finally:
        sys.path.pop(0)

    # Check port availability
    import socket

    def check_port(port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        logger.error(f"Port {zmq_port} is already in use")
        return

    # Only support metadata file, fail fast for everything else
    if not passages_file or not passages_file.endswith(".meta.json"):
        raise ValueError("Only metadata files (.meta.json) are supported")

    # Load metadata to get passage sources
    with open(passages_file, "r") as f:
        meta = json.load(f)

    passages = PassageManager(meta["passage_sources"])
    logger.info(
        f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
    )

    # Import protobuf after ensuring the path is correct
    try:
        from . import embedding_pb2
    except ImportError as e:
        logger.error(f"Failed to import protobuf module: {e}")
        return

    def zmq_server_thread():
        """ZMQ server thread using REP socket for universal compatibility"""
        context = zmq.Context()
        socket = context.socket(zmq.REP)  # REP socket for both BaseSearcher and DiskANN C++ REQ clients
        socket.bind(f"tcp://*:{zmq_port}")
        logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

        socket.setsockopt(zmq.RCVTIMEO, 300000)
        socket.setsockopt(zmq.SNDTIMEO, 300000)

        while True:
            try:
                # REP socket receives single-part messages
                message = socket.recv()

                # Check for empty messages - REP socket requires response to every request
                if len(message) == 0:
                    logger.debug("Received empty message, sending empty response")
                    socket.send(b"")  # REP socket must respond to every request
                    continue

                logger.debug(f"Received ZMQ request of size {len(message)} bytes")
                logger.debug(f"Message preview: {message[:50]}")  # Show first 50 bytes

                e2e_start = time.time()

                # Try protobuf first (for DiskANN C++ node_ids requests - primary use case)
                texts = []
                node_ids = []
                is_text_request = False

                try:
                    req_proto = embedding_pb2.NodeEmbeddingRequest()
                    req_proto.ParseFromString(message)
                    node_ids = list(req_proto.node_ids)

                    if not node_ids:
                        raise RuntimeError(f"PROTOBUF: Received empty node_ids! Message size: {len(message)}")

                    logger.info(
                        f"✅ PROTOBUF: Node ID request for {len(node_ids)} node embeddings: {node_ids[:10]}"
                    )
                except Exception as protobuf_error:
                    logger.debug(f"Protobuf parsing failed: {protobuf_error}")
                    # Fallback to msgpack (for BaseSearcher direct text requests)
                    try:
                        import msgpack

                        request = msgpack.unpackb(message)
                        # For BaseSearcher compatibility, request is a list of texts directly
                        if isinstance(request, list) and all(
                            isinstance(item, str) for item in request
                        ):
                            texts = request
                            is_text_request = True
                            logger.info(
                                f"✅ MSGPACK: Direct text request for {len(texts)} texts"
                            )
                        else:
                            raise ValueError("Not a valid msgpack text request")
                    except Exception as msgpack_error:
                        raise RuntimeError(
                            f"Both protobuf and msgpack parsing failed! Protobuf: {protobuf_error}, Msgpack: {msgpack_error}"
                        )

                # Look up texts by node IDs (only if not direct text request)
                if not is_text_request:
                    for nid in node_ids:
                        try:
                            passage_data = passages.get_passage(str(nid))
                            txt = passage_data["text"]
                            if not txt:
                                raise RuntimeError(
                                    f"FATAL: Empty text for passage ID {nid}"
                                )
                            texts.append(txt)
                        except KeyError as e:
                            logger.error(f"Passage ID {nid} not found: {e}")
                            raise e
                        except Exception as e:
                            logger.error(f"Exception looking up passage ID {nid}: {e}")
                            raise

                # Debug logging
                logger.debug(
                    f"Processing {len(texts)} texts"
                )
                logger.debug(
                    f"Text lengths: {[len(t) for t in texts[:5]]}"
                )  # Show first 5

                # Process embeddings using unified computation
                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                logger.info(
                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                )

                # Prepare response based on request type
                if is_text_request:
                    # For BaseSearcher compatibility: return msgpack format
                    import msgpack

                    response_data = msgpack.packb(embeddings.tolist())
                else:
                    # For DiskANN C++ compatibility: return protobuf format
                    resp_proto = embedding_pb2.NodeEmbeddingResponse()
                    hidden_contiguous = np.ascontiguousarray(
                        embeddings, dtype=np.float32
                    )

                    # Serialize embeddings data
                    resp_proto.embeddings_data = hidden_contiguous.tobytes()
                    resp_proto.dimensions.append(hidden_contiguous.shape[0])
                    resp_proto.dimensions.append(hidden_contiguous.shape[1])

                    response_data = resp_proto.SerializeToString()

                # Send response back to the client
                socket.send(response_data)

                e2e_end = time.time()
                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
                logger.debug("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                logger.error(f"Error in ZMQ server loop: {e}")
                import traceback

                traceback.print_exc()
                raise

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    logger.info(f"Started DiskANN ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("DiskANN Server shutting down...")
        return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="DiskANN Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument(
        "--passages-file",
        type=str,
        help="Metadata JSON file containing passage sources",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model name",
    )
    parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx"],
        help="Embedding backend mode",
    )

    args = parser.parse_args()

    # Create and start the DiskANN embedding server
    create_diskann_embedding_server(
        passages_file=args.passages_file,
        zmq_port=args.zmq_port,
        model_name=args.model_name,
        embedding_mode=args.embedding_mode,
    )
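As a rough illustration of the msgpack fallback branch in the server above, a text-based client (the role BaseSearcher plays) could look like this; the port and texts are placeholders and the server is assumed to be running:

```python
import msgpack
import numpy as np
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.REQ)               # REQ pairs with the server's REP socket
sock.connect("tcp://localhost:5555")

# The msgpack branch expects a plain list of strings...
sock.send(msgpack.packb(["what is leann?", "how does diskann work?"]))

# ...and replies with a msgpack-encoded list of embedding vectors.
embeddings = np.array(msgpack.unpackb(sock.recv()), dtype=np.float32)
print(embeddings.shape)                  # (2, embedding_dim)
```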
@@ -1,705 +0,0 @@
#!/usr/bin/env python3
"""
Embedding server for leann-backend-diskann - Fixed ZMQ REQ-REP pattern
"""

import pickle
import argparse
import time
import json
from typing import Dict, Any, Optional, Union

from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
from pathlib import Path
import logging

RED = "\033[91m"

# Set up logging based on environment variable
LOG_LEVEL = os.getenv('LEANN_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
RESET = "\033[0m"

# --- New Passage Loader from HNSW backend ---
class SimplePassageLoader:
    """
    Simple passage loader that replaces config.py dependencies
    """
    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
        self.passages_data = passages_data or {}
        self._meta_path = ''

    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
        """Get passage by ID"""
        str_id = str(passage_id)
        if str_id in self.passages_data:
            return {"text": self.passages_data[str_id]}
        else:
            # Return empty text for missing passages
            return {"text": ""}

    def __len__(self) -> int:
        return len(self.passages_data)

    def keys(self):
        return self.passages_data.keys()


def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
    """
    Load passages using metadata file with PassageManager for lazy loading
    """
    # Load metadata to get passage sources
    with open(meta_file, 'r') as f:
        meta = json.load(f)

    # Import PassageManager dynamically to avoid circular imports
    import sys
    from pathlib import Path

    # Find the leann package directory relative to this file
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))

    try:
        from leann.api import PassageManager
        passage_manager = PassageManager(meta['passage_sources'])
    finally:
        sys.path.pop(0)

    print(f"Initialized lazy passage loading for {len(passage_manager.global_offset_map)} passages")

    class LazyPassageLoader(SimplePassageLoader):
        def __init__(self, passage_manager):
            self.passage_manager = passage_manager
            # Initialize parent with empty data
            super().__init__({})

        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
            """Get passage by ID with lazy loading"""
            try:
                int_id = int(passage_id)
                string_id = str(int_id)
                passage_data = self.passage_manager.get_passage(string_id)
                if passage_data and passage_data.get("text"):
                    return {"text": passage_data["text"]}
                else:
                    raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
            except Exception as e:
                raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")

        def __len__(self) -> int:
            return len(self.passage_manager.global_offset_map)

        def keys(self):
            return self.passage_manager.global_offset_map.keys()

    loader = LazyPassageLoader(passage_manager)
    loader._meta_path = meta_file
    return loader

def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
    Load passages from a JSONL file with label map support
    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
    """

    if not os.path.exists(passages_file):
        raise FileNotFoundError(f"Passages file {passages_file} not found.")

    if not passages_file.endswith('.jsonl'):
        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")

    # Load passages directly by their sequential IDs
    passages_data = {}
    with open(passages_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                passages_data[passage['id']] = passage['text']

    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file}")
    return SimplePassageLoader(passages_data)

def create_embedding_server_thread(
    zmq_port=5555,
    model_name="sentence-transformers/all-mpnet-base-v2",
    max_batch_size=128,
    passages_file: Optional[str] = None,
    embedding_mode: str = "sentence-transformers",
    enable_warmup: bool = False,
):
    """
    Create and run embedding server in the current thread
    This function is designed to be called in a separate thread
    """
    logger.info(f"Initializing embedding server thread on port {zmq_port}")

    try:
        # Check if port is already occupied
        import socket

        def check_port(port):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                return s.connect_ex(('localhost', port)) == 0

        if check_port(zmq_port):
            print(f"{RED}Port {zmq_port} is already in use{RESET}")
            return

        # Auto-detect mode based on model name if not explicitly set
        if embedding_mode == "sentence-transformers" and model_name.startswith("text-embedding-"):
            embedding_mode = "openai"

        if embedding_mode == "mlx":
            from leann.api import compute_embeddings_mlx
            import torch
            logger.info("Using MLX for embeddings")
            # Set device to CPU for compatibility with DeviceTimer class
            device = torch.device("cpu")
            cuda_available = False
            mps_available = False
        elif embedding_mode == "openai":
            from leann.api import compute_embeddings_openai
            import torch
            logger.info("Using OpenAI API for embeddings")
            # Set device to CPU for compatibility with DeviceTimer class
            device = torch.device("cpu")
            cuda_available = False
            mps_available = False
        elif embedding_mode == "sentence-transformers":
            # Initialize model
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            import torch

            # Select device
            mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
            cuda_available = torch.cuda.is_available()

            if cuda_available:
                device = torch.device("cuda")
                logger.info("Using CUDA device")
            elif mps_available:
                device = torch.device("mps")
                logger.info("Using MPS device (Apple Silicon)")
            else:
                device = torch.device("cpu")
                logger.info("Using CPU device")

            # Load model
            logger.info(f"Loading model {model_name}")
            model = AutoModel.from_pretrained(model_name).to(device).eval()

            # Optimize model
            if cuda_available or mps_available:
                try:
                    model = model.half()
                    model = torch.compile(model)
                    logger.info(f"Using FP16 precision with model: {model_name}")
                except Exception as e:
                    print(f"WARNING: Model optimization failed: {e}")
        else:
            raise ValueError(f"Unsupported embedding mode: {embedding_mode}. Supported modes: sentence-transformers, mlx, openai")

        # Load passages from file if provided
        if passages_file and os.path.exists(passages_file):
            # Check if it's a metadata file or a single passages file
            if passages_file.endswith('.meta.json'):
                passages = load_passages_from_metadata(passages_file)
            else:
                # Try to find metadata file in same directory
                passages_dir = Path(passages_file).parent
                meta_files = list(passages_dir.glob("*.meta.json"))
                if meta_files:
                    print(f"Found metadata file: {meta_files[0]}, using lazy loading")
                    passages = load_passages_from_metadata(str(meta_files[0]))
                else:
                    # Fallback to original single file loading (will cause warnings)
                    print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
                    passages = load_passages_from_file(passages_file)
        else:
            print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
            passages = SimplePassageLoader()

        logger.info(f"Loaded {len(passages)} passages.")

        def client_warmup(zmq_port):
            """Perform client-side warmup for DiskANN server"""
            time.sleep(2)
            print(f"Performing client-side warmup with model {model_name}...")

            # Get actual passage IDs from the loaded passages
            sample_ids = []
            if hasattr(passages, 'keys') and len(passages) > 0:
                available_ids = list(passages.keys())
                # Take up to 5 actual IDs, but at least 1
                sample_ids = available_ids[:min(5, len(available_ids))]
                print(f"Using actual passage IDs for warmup: {sample_ids}")
            else:
                print("No passages available for warmup, skipping warmup...")
                return

            try:
                context = zmq.Context()
                socket = context.socket(zmq.REQ)
                socket.connect(f"tcp://localhost:{zmq_port}")
                socket.setsockopt(zmq.RCVTIMEO, 30000)
                socket.setsockopt(zmq.SNDTIMEO, 30000)

                try:
                    ids_to_send = [int(x) for x in sample_ids]
                except ValueError:
                    print("Warning: Could not convert sample IDs to integers, skipping warmup")
                    return

                if not ids_to_send:
                    print("Skipping warmup send.")
                    return

                # Use protobuf format for warmup
                from . import embedding_pb2
                req_proto = embedding_pb2.NodeEmbeddingRequest()
                req_proto.node_ids.extend(ids_to_send)
                request_bytes = req_proto.SerializeToString()

                for i in range(3):
                    print(f"Sending warmup request {i + 1}/3 via ZMQ (Protobuf)...")
                    socket.send(request_bytes)
                    response_bytes = socket.recv()

                    resp_proto = embedding_pb2.NodeEmbeddingResponse()
                    resp_proto.ParseFromString(response_bytes)
                    embeddings_count = resp_proto.dimensions[0] if resp_proto.dimensions else 0
                    print(f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings")
                    time.sleep(0.1)

                print("Client-side Protobuf ZMQ warmup complete")
                socket.close()
                context.term()
            except Exception as e:
                print(f"Error during Protobuf ZMQ warmup: {e}")

        class DeviceTimer:
            """Device timer"""
            def __init__(self, name="", device=device):
                self.name = name
                self.device = device
                self.start_time = 0
                self.end_time = 0

                if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
                    self.start_event = torch.cuda.Event(enable_timing=True)
                    self.end_event = torch.cuda.Event(enable_timing=True)
                else:
                    self.start_event = None
                    self.end_event = None

            @contextmanager
            def timing(self):
                self.start()
                yield
                self.end()

            def start(self):
                if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
                    torch.cuda.synchronize()
                    self.start_event.record()
                else:
                    if embedding_mode == "sentence-transformers" and self.device.type == "mps":
                        torch.mps.synchronize()
                    self.start_time = time.time()

            def end(self):
                if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
                    self.end_event.record()
                    torch.cuda.synchronize()
                else:
                    if embedding_mode == "sentence-transformers" and self.device.type == "mps":
                        torch.mps.synchronize()
                    self.end_time = time.time()

            def elapsed_time(self):
                if embedding_mode == "sentence-transformers" and torch.cuda.is_available():
                    return self.start_event.elapsed_time(self.end_event) / 1000.0
                else:
                    return self.end_time - self.start_time

            def print_elapsed(self):
                elapsed = self.elapsed_time()
                print(f"[{self.name}] Elapsed time: {elapsed:.3f}s")

        def process_batch_pytorch(texts_batch, ids_batch, missing_ids):
            """Process text batch"""
            if not texts_batch:
                return np.array([])

            # Filter out empty texts and their corresponding IDs
            valid_texts = []
            valid_ids = []
            for i, text in enumerate(texts_batch):
                if text.strip():  # Only include non-empty texts
                    valid_texts.append(text)
                    valid_ids.append(ids_batch[i])

            if not valid_texts:
                print("WARNING: No valid texts in batch")
                return np.array([])

            # Tokenize
            token_timer = DeviceTimer("tokenization")
            with token_timer.timing():
                inputs = tokenizer(
                    valid_texts,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                ).to(device)

            # Compute embeddings
            embed_timer = DeviceTimer("embedding computation")
            with embed_timer.timing():
                with torch.no_grad():
                    outputs = model(**inputs)
                    hidden_states = outputs.last_hidden_state

                    # Mean pooling
                    attention_mask = inputs['attention_mask']
                    mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                    batch_embeddings = sum_embeddings / sum_mask
            embed_timer.print_elapsed()

            return batch_embeddings.cpu().numpy()

        # ZMQ server main loop - modified to use REP socket
        context = zmq.Context()
        socket = context.socket(zmq.ROUTER)  # Changed to REP socket
        socket.bind(f"tcp://127.0.0.1:{zmq_port}")
        print(f"INFO: ZMQ ROUTER server listening on port {zmq_port}")

        # Set timeouts
        socket.setsockopt(zmq.RCVTIMEO, 5000)  # 5 second receive timeout
        socket.setsockopt(zmq.SNDTIMEO, 300000)  # 300 second send timeout

        from . import embedding_pb2

        print(f"INFO: Embedding server ready to serve requests")

        # Start warmup thread if enabled
        if enable_warmup and len(passages) > 0:
            import threading
            print(f"Warmup enabled: starting warmup thread")
            warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
            warmup_thread.daemon = True
            warmup_thread.start()
        else:
            print(f"Warmup disabled or no passages available (enable_warmup={enable_warmup}, passages={len(passages)})")

        while True:
            try:
                parts = socket.recv_multipart()

                # --- Restore robust message format detection ---
                # Must check parts length to avoid IndexError
                if len(parts) >= 3:
                    identity = parts[0]
                    # empty = parts[1]  # We usually don't care about the middle empty frame
                    message = parts[2]
                elif len(parts) == 2:
                    # Can also handle cases without empty frame
                    identity = parts[0]
                    message = parts[1]
                else:
                    # If received message format is wrong, print warning and ignore it instead of crashing
                    print(f"WARNING: Received unexpected message format with {len(parts)} parts. Ignoring.")
                    continue
                print(f"INFO: Received ZMQ request from client {identity.hex()[:8]}, size {len(message)} bytes")

                # Handle control messages (MessagePack format)
                try:
                    request_payload = msgpack.unpackb(message)
                    if isinstance(request_payload, list) and len(request_payload) >= 1:
                        if request_payload[0] == "__QUERY_META_PATH__":
                            # Return the current meta path being used by the server
                            current_meta_path = getattr(passages, '_meta_path', '') if hasattr(passages, '_meta_path') else ''
                            response = [current_meta_path]
                            socket.send_multipart([identity, b'', msgpack.packb(response)])
                            continue

                        elif request_payload[0] == "__UPDATE_META_PATH__" and len(request_payload) >= 2:
                            # Update the server's meta path and reload passages
                            new_meta_path = request_payload[1]
                            try:
                                print(f"INFO: Updating server meta path to: {new_meta_path}")
                                # Reload passages from the new meta file
                                passages = load_passages_from_metadata(new_meta_path)
                                # Store the meta path for future queries
                                passages._meta_path = new_meta_path
                                response = ["SUCCESS"]
                                print(f"INFO: Successfully updated meta path and reloaded {len(passages)} passages")
                            except Exception as e:
                                print(f"ERROR: Failed to update meta path: {e}")
                                response = ["FAILED", str(e)]
                            socket.send_multipart([identity, b'', msgpack.packb(response)])
                            continue

                        elif request_payload[0] == "__QUERY_MODEL__":
                            # Return the current model being used by the server
                            response = [model_name]
                            socket.send_multipart([identity, b'', msgpack.packb(response)])
                            continue

                        elif request_payload[0] == "__UPDATE_MODEL__" and len(request_payload) >= 2:
                            # Update the server's embedding model
                            new_model_name = request_payload[1]
                            try:
                                print(f"INFO: Updating server model from {model_name} to: {new_model_name}")

                                # Clean up old model to free memory
                                if not use_mlx:
                                    print("INFO: Releasing old model from memory...")
                                    old_model = model
                                    old_tokenizer = tokenizer

                                # Load new tokenizer first
                                print(f"Loading new tokenizer for {new_model_name}...")
                                tokenizer = AutoTokenizer.from_pretrained(new_model_name, use_fast=True)

                                # Load new model
                                print(f"Loading new model {new_model_name}...")
                                model = AutoModel.from_pretrained(new_model_name).to(device).eval()

                                # Optimize new model
                                if cuda_available or mps_available:
                                    try:
                                        model = model.half()
                                        model = torch.compile(model)
                                        print(f"INFO: Using FP16 precision with model: {new_model_name}")
                                    except Exception as e:
                                        print(f"WARNING: Model optimization failed: {e}")

                                # Now safely delete old model after new one is loaded
                                del old_model
                                del old_tokenizer

                                # Clear GPU cache if available
                                if device.type == "cuda":
                                    torch.cuda.empty_cache()
                                    print("INFO: Cleared CUDA cache")
                                elif device.type == "mps":
                                    torch.mps.empty_cache()
                                    print("INFO: Cleared MPS cache")

                                # Force garbage collection
                                import gc
                                gc.collect()
                                print("INFO: Memory cleanup completed")

                                # Update model name
                                model_name = new_model_name

                                response = ["SUCCESS"]
                                print(f"INFO: Successfully updated model to: {new_model_name}")
                            except Exception as e:
                                print(f"ERROR: Failed to update model: {e}")
                                response = ["FAILED", str(e)]
                            socket.send_multipart([identity, b'', msgpack.packb(response)])
                            continue
                except:
                    # Not a control message, continue with normal protobuf processing
                    pass

                e2e_start = time.time()
                lookup_timer = DeviceTimer("text lookup")

                # Parse request
                req_proto = embedding_pb2.NodeEmbeddingRequest()
                req_proto.ParseFromString(message)
                node_ids = req_proto.node_ids
                print(f"INFO: Request for {len(node_ids)} node embeddings: {list(node_ids)}")

                # Add debug information
                if len(node_ids) > 0:
                    print(f"DEBUG: Node ID range: {min(node_ids)} to {max(node_ids)}")

                # Look up texts
                texts = []
                missing_ids = []
                with lookup_timer.timing():
                    for nid in node_ids:
                        txtinfo = passages[nid]
                        txt = txtinfo["text"]
                        if txt:
                            texts.append(txt)
                        else:
                            # If text is empty, we still need a placeholder for batch processing,
                            # but record its ID as missing
                            texts.append("")
                            missing_ids.append(nid)
                lookup_timer.print_elapsed()

                if missing_ids:
                    print(f"WARNING: Missing passages for IDs: {missing_ids}")

                # Process batch
                total_size = len(texts)
                print(f"INFO: Total batch size: {total_size}, max_batch_size: {max_batch_size}")

                all_embeddings = []

                if total_size > max_batch_size:
                    print(f"INFO: Splitting batch of size {total_size} into chunks of {max_batch_size}")
                    for i in range(0, total_size, max_batch_size):
                        end_idx = min(i + max_batch_size, total_size)
                        print(f"INFO: Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")

                        chunk_texts = texts[i:end_idx]
                        chunk_ids = node_ids[i:end_idx]

                        if embedding_mode == "mlx":
                            embeddings_chunk = compute_embeddings_mlx(chunk_texts, model_name, batch_size=16)
                        elif embedding_mode == "openai":
                            embeddings_chunk = compute_embeddings_openai(chunk_texts, model_name)
                        else:  # sentence-transformers
                            embeddings_chunk = process_batch_pytorch(chunk_texts, chunk_ids, missing_ids)
                        all_embeddings.append(embeddings_chunk)

                        if embedding_mode == "sentence-transformers":
                            if cuda_available:
                                torch.cuda.empty_cache()
                            elif device.type == "mps":
                                torch.mps.empty_cache()

                    hidden = np.vstack(all_embeddings)
                    print(f"INFO: Combined embeddings shape: {hidden.shape}")
                else:
                    if embedding_mode == "mlx":
                        hidden = compute_embeddings_mlx(texts, model_name, batch_size=16)
                    elif embedding_mode == "openai":
                        hidden = compute_embeddings_openai(texts, model_name)
                    else:  # sentence-transformers
                        hidden = process_batch_pytorch(texts, node_ids, missing_ids)

                # Serialize response
                ser_start = time.time()

                resp_proto = embedding_pb2.NodeEmbeddingResponse()
                hidden_contiguous = np.ascontiguousarray(hidden, dtype=np.float32)
                resp_proto.embeddings_data = hidden_contiguous.tobytes()
                resp_proto.dimensions.append(hidden_contiguous.shape[0])
                resp_proto.dimensions.append(hidden_contiguous.shape[1])
                resp_proto.missing_ids.extend(missing_ids)

                response_data = resp_proto.SerializeToString()

                # REP socket sends a single response
                socket.send_multipart([identity, b'', response_data])

                ser_end = time.time()

                print(f"INFO: Serialize time: {ser_end - ser_start:.6f} seconds")

                if embedding_mode == "sentence-transformers":
                    if device.type == "cuda":
                        torch.cuda.synchronize()
                    elif device.type == "mps":
                        torch.mps.synchronize()
                e2e_end = time.time()
                print(f"INFO: ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

            except zmq.Again:
                print("INFO: ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                print(f"ERROR: Error in ZMQ server: {e}")
                try:
                    # Send empty response to maintain REQ-REP state
                    empty_resp = embedding_pb2.NodeEmbeddingResponse()
                    socket.send(empty_resp.SerializeToString())
                except:
                    # If sending fails, recreate socket
                    socket.close()
                    socket = context.socket(zmq.REP)
                    socket.bind(f"tcp://127.0.0.1:{zmq_port}")
                    socket.setsockopt(zmq.RCVTIMEO, 5000)
                    socket.setsockopt(zmq.SNDTIMEO, 300000)
                    print("INFO: ZMQ socket recreated after error")

    except Exception as e:
        print(f"ERROR: Failed to start embedding server: {e}")
        raise

def create_embedding_server(
    domain="demo",
    load_passages=True,
    load_embeddings=False,
    use_fp16=True,
    use_int8=False,
    use_cuda_graphs=False,
    zmq_port=5555,
    max_batch_size=128,
    lazy_load_passages=False,
    model_name="sentence-transformers/all-mpnet-base-v2",
    passages_file: Optional[str] = None,
    embedding_mode: str = "sentence-transformers",
    enable_warmup: bool = False,
):
    """
    The original create_embedding_server function, kept unchanged.
    This is the blocking variant, meant to be run directly.
    """
    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file, embedding_mode, enable_warmup)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument("--domain", type=str, default="demo", help="Domain name")
    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
    parser.add_argument("--load-passages", action="store_true", default=True)
    parser.add_argument("--load-embeddings", action="store_true", default=False)
    parser.add_argument("--use-fp16", action="store_true", default=False)
    parser.add_argument("--use-int8", action="store_true", default=False)
    parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
    parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
    parser.add_argument("--lazy-load-passages", action="store_true", default=True)
    parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                        help="Embedding model name")
    parser.add_argument("--embedding-mode", type=str, default="sentence-transformers",
                        choices=["sentence-transformers", "mlx", "openai"],
                        help="Embedding backend mode")
    parser.add_argument("--use-mlx", action="store_true", default=False, help="Use MLX backend for embeddings (deprecated: use --embedding-mode mlx)")
    parser.add_argument("--disable-warmup", action="store_true", default=False, help="Disable warmup requests on server start")
    args = parser.parse_args()

    # Handle backward compatibility with use_mlx
    embedding_mode = args.embedding_mode
    if args.use_mlx:
        embedding_mode = "mlx"

    create_embedding_server(
        domain=args.domain,
        load_passages=args.load_passages,
        load_embeddings=args.load_embeddings,
        use_fp16=args.use_fp16,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        zmq_port=args.zmq_port,
        max_batch_size=args.max_batch_size,
        lazy_load_passages=args.lazy_load_passages,
        model_name=args.model_name,
        passages_file=args.passages_file,
        embedding_mode=embedding_mode,
        enable_warmup=not args.disable_warmup,
    )
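For readers comparing the removed server above with its replacement earlier in this commit: the old loop used a ROUTER socket, which handles an explicit identity envelope, while the new one uses REP, which hides it. A rough, non-authoritative sketch of the framing difference (ports are placeholders, and the recv/send calls are commented out because they block without a peer):

```python
import zmq

ctx = zmq.Context.instance()

# ROUTER style (old server): the client identity arrives as the first frame
# and every reply must echo it back.
router = ctx.socket(zmq.ROUTER)
router.bind("tcp://127.0.0.1:5556")
# identity, _empty, payload = router.recv_multipart()
# router.send_multipart([identity, b"", reply])

# REP style (new server): ZeroMQ strips the envelope, so the handler only
# ever sees the payload and answers with a single send().
rep = ctx.socket(zmq.REP)
rep.bind("tcp://127.0.0.1:5557")
# payload = rep.recv()
# rep.send(reply)
```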
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: af2a26481e...25339b0341
@@ -142,12 +142,12 @@ class HNSWSearcher(BaseSearcher):
         self,
         query: np.ndarray,
         top_k: int,
+        zmq_port: Optional[int] = None,
         complexity: int = 64,
         beam_width: int = 1,
         prune_ratio: float = 0.0,
         recompute_embeddings: bool = True,
         pruning_strategy: Literal["global", "local", "proportional"] = "global",
-        expected_zmq_port: Optional[int] = None,
         batch_size: int = 0,
         **kwargs,
     ) -> Dict[str, Any]:
@@ -165,7 +165,7 @@ class HNSWSearcher(BaseSearcher):
                 - "global": Use global PQ queue size for selection (default)
                 - "local": Local pruning, sort and select best candidates
                 - "proportional": Base selection on new neighbor count ratio
-            expected_zmq_port: ZMQ port for embedding server
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
             **kwargs: Additional HNSW-specific parameters (for legacy compatibility)

@@ -177,6 +177,11 @@ class HNSWSearcher(BaseSearcher):
         if not recompute_embeddings:
             if self.is_pruned:
                 raise RuntimeError("Recompute is required for pruned index.")
+        if recompute_embeddings:
+            if zmq_port is None:
+                raise ValueError(
+                    "zmq_port must be provided if recompute_embeddings is True"
+                )

         if query.dtype != np.float32:
             query = query.astype(np.float32)
@@ -184,7 +189,10 @@ class HNSWSearcher(BaseSearcher):
         faiss.normalize_L2(query)

         params = faiss.SearchParametersHNSW()
-        params.zmq_port = expected_zmq_port
+        if zmq_port is not None:
+            params.zmq_port = (
+                zmq_port  # C++ code won't use this if recompute_embeddings is False
+            )
         params.efSearch = complexity
         params.beam_size = beam_width

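The net effect of these hunks is that the port now travels under the keyword `zmq_port` and is validated up front. A hedged usage sketch follows; the import path, constructor arguments, dimensionality, and port value are illustrative assumptions, while the `search()` keywords come from the signature above:

```python
import numpy as np
from leann_backend_hnsw import HNSWSearcher  # import path is an assumption

searcher = HNSWSearcher("./indexes/my-docs/documents.leann")  # constructor usage assumed

query = np.random.rand(1, 768).astype(np.float32)  # 768-dim is an assumption

# recompute_embeddings=True (the default) now requires an explicit ZMQ port
results = searcher.search(
    query,
    top_k=5,
    zmq_port=5557,
    complexity=64,
    recompute_embeddings=True,
    pruning_strategy="global",
)

# Without a port, the new guard raises immediately:
#   searcher.search(query, top_k=5)
#   -> ValueError: zmq_port must be provided if recompute_embeddings is True
```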
@@ -450,7 +450,7 @@ class LeannSearcher:
             prune_ratio=prune_ratio,
             recompute_embeddings=recompute_embeddings,
             pruning_strategy=pruning_strategy,
-            expected_zmq_port=zmq_port,
+            zmq_port=zmq_port,
             **kwargs,
         )
         search_time = time.time() - start_time
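At the `LeannSearcher` level the public keyword stays `zmq_port`; it is simply forwarded to the backend under its new name instead of `expected_zmq_port`. A hedged sketch, where the import path, index path, positional query string, and `top_k` value are assumptions and only the forwarded keywords are taken from the hunk above:

```python
from leann import LeannSearcher  # import path is an assumption

searcher = LeannSearcher(index_path="~/.leann/indexes/my-docs/documents.leann")
results = searcher.search(
    "machine learning concepts",  # positional query string is an assumption
    top_k=5,
    recompute_embeddings=True,
    pruning_strategy="global",
    zmq_port=5557,                # passed through to the backend as zmq_port
)
```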
@@ -1,10 +1,6 @@
-#!/usr/bin/env python3
 import argparse
 import asyncio
-import sys
 from pathlib import Path
-from typing import Optional
-import os

 from llama_index.core import SimpleDirectoryReader
 from llama_index.core.node_parser import SentenceSplitter
@@ -41,7 +37,7 @@ Examples:
   leann search my-docs "query" # Search in my-docs index
   leann ask my-docs "question" # Ask my-docs index
   leann list # List all stored indexes
-"""
+""",
         )

         subparsers = parser.add_subparsers(dest="command", help="Available commands")
@@ -49,10 +45,18 @@ Examples:
         # Build command
         build_parser = subparsers.add_parser("build", help="Build document index")
         build_parser.add_argument("index_name", help="Index name")
-        build_parser.add_argument("--docs", type=str, required=True, help="Documents directory")
-        build_parser.add_argument("--backend", type=str, default="hnsw", choices=["hnsw", "diskann"])
-        build_parser.add_argument("--embedding-model", type=str, default="facebook/contriever")
-        build_parser.add_argument("--force", "-f", action="store_true", help="Force rebuild")
+        build_parser.add_argument(
+            "--docs", type=str, required=True, help="Documents directory"
+        )
+        build_parser.add_argument(
+            "--backend", type=str, default="hnsw", choices=["hnsw", "diskann"]
+        )
+        build_parser.add_argument(
+            "--embedding-model", type=str, default="facebook/contriever"
+        )
+        build_parser.add_argument(
+            "--force", "-f", action="store_true", help="Force rebuild"
+        )
         build_parser.add_argument("--graph-degree", type=int, default=32)
         build_parser.add_argument("--complexity", type=int, default=64)
         build_parser.add_argument("--num-threads", type=int, default=1)
@@ -68,12 +72,21 @@ Examples:
         search_parser.add_argument("--beam-width", type=int, default=1)
         search_parser.add_argument("--prune-ratio", type=float, default=0.0)
         search_parser.add_argument("--recompute-embeddings", action="store_true")
-        search_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
+        search_parser.add_argument(
+            "--pruning-strategy",
+            choices=["global", "local", "proportional"],
+            default="global",
+        )

         # Ask command
         ask_parser = subparsers.add_parser("ask", help="Ask questions")
         ask_parser.add_argument("index_name", help="Index name")
-        ask_parser.add_argument("--llm", type=str, default="ollama", choices=["simulated", "ollama", "hf", "openai"])
+        ask_parser.add_argument(
+            "--llm",
+            type=str,
+            default="ollama",
+            choices=["simulated", "ollama", "hf", "openai"],
+        )
         ask_parser.add_argument("--model", type=str, default="qwen3:8b")
         ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
         ask_parser.add_argument("--interactive", "-i", action="store_true")
@@ -82,7 +95,11 @@ Examples:
         ask_parser.add_argument("--beam-width", type=int, default=1)
         ask_parser.add_argument("--prune-ratio", type=float, default=0.0)
         ask_parser.add_argument("--recompute-embeddings", action="store_true")
-        ask_parser.add_argument("--pruning-strategy", choices=["global", "local", "proportional"], default="global")
+        ask_parser.add_argument(
+            "--pruning-strategy",
+            choices=["global", "local", "proportional"],
+            default="global",
+        )

         # List command
         list_parser = subparsers.add_parser("list", help="List all indexes")
@@ -93,13 +110,17 @@ Examples:
         print("Stored LEANN indexes:")

         if not self.indexes_dir.exists():
-            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
+            print(
+                "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
+            )
             return

         index_dirs = [d for d in self.indexes_dir.iterdir() if d.is_dir()]

         if not index_dirs:
-            print("No indexes found. Use 'leann build <name> --docs <dir>' to create one.")
+            print(
+                "No indexes found. Use 'leann build <name> --docs <dir>' to create one."
+            )
             return

         print(f"Found {len(index_dirs)} indexes:")
@@ -110,13 +131,15 @@ Examples:
             print(f" {i}. {index_name} [{status}]")
             if self.index_exists(index_name):
                 meta_file = index_dir / "documents.leann.meta.json"
-                size_mb = sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)
+                size_mb = sum(
+                    f.stat().st_size for f in index_dir.iterdir() if f.is_file()
+                ) / (1024 * 1024)
                 print(f" Size: {size_mb:.1f} MB")

         if index_dirs:
             example_name = index_dirs[0].name
             print(f"\nUsage:")
-            print(f" leann search {example_name} \"your query\"")
+            print(f' leann search {example_name} "your query"')
             print(f" leann ask {example_name} --interactive")

     def load_documents(self, docs_dir: str):
@@ -179,7 +202,9 @@ Examples:
         index_path = self.get_index_path(index_name)

         if not self.index_exists(index_name):
-            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
+            print(
+                f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
+            )
             return

         searcher = LeannSearcher(index_path=index_path)
@@ -190,7 +215,7 @@ Examples:
             beam_width=args.beam_width,
             prune_ratio=args.prune_ratio,
             recompute_embeddings=args.recompute_embeddings,
-            pruning_strategy=args.pruning_strategy
+            pruning_strategy=args.pruning_strategy,
         )

         print(f"Search results for '{query}' (top {len(results)}):")
@@ -204,7 +229,9 @@ Examples:
         index_path = self.get_index_path(index_name)

         if not self.index_exists(index_name):
-            print(f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it.")
+            print(
+                f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir>' to create it."
+            )
             return

         print(f"Starting chat with index '{index_name}'...")
@@ -222,7 +249,7 @@ Examples:

         while True:
             user_input = input("\nYou: ").strip()
-            if user_input.lower() in ['quit', 'exit', 'q']:
+            if user_input.lower() in ["quit", "exit", "q"]:
                 print("Goodbye!")
                 break

@@ -236,7 +263,7 @@ Examples:
                 beam_width=args.beam_width,
                 prune_ratio=args.prune_ratio,
                 recompute_embeddings=args.recompute_embeddings,
-                pruning_strategy=args.pruning_strategy
+                pruning_strategy=args.pruning_strategy,
             )
             print(f"LEANN: {response}")
         else:
@@ -249,7 +276,7 @@ Examples:
                 beam_width=args.beam_width,
                 prune_ratio=args.prune_ratio,
                 recompute_embeddings=args.recompute_embeddings,
-                pruning_strategy=args.pruning_strategy
+                pruning_strategy=args.pruning_strategy,
             )
             print(f"LEANN: {response}")

@@ -277,6 +304,7 @@ Examples:

 def main():
     import dotenv
+
     dotenv.load_dotenv()

     cli = LeannCLI()
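The index size reported by `leann list` is the one-liner that the hunk above merely re-wraps. As a standalone hedged sketch, where the helper name and the example directory are assumptions and only the arithmetic is taken from the diff:

```python
from pathlib import Path

def index_size_mb(index_dir: Path) -> float:
    # Same arithmetic as the CLI: total size of regular files in the index dir, in megabytes
    return sum(f.stat().st_size for f in index_dir.iterdir() if f.is_file()) / (1024 * 1024)

print(f"Size: {index_size_mb(Path('.')):.1f} MB")  # current directory used just for illustration
```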
@@ -60,6 +60,9 @@ def compute_embeddings_sentence_transformers(
     """
     Compute embeddings using SentenceTransformer with model caching
     """
+    # Handle empty input
+    if not texts:
+        raise ValueError("Cannot compute embeddings for empty text list")
     logger.info(
         f"Computing embeddings for {len(texts)} texts using SentenceTransformer, model: '{model_name}'"
     )
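The new guard turns an empty batch into an immediate, descriptive error instead of a downstream failure inside the model. A hedged sketch of the caller-visible behaviour; the import path is an assumption, and the exact signature beyond `texts` and `model_name` is not shown in the diff:

```python
from leann.embedding_compute import compute_embeddings_sentence_transformers  # path assumed

try:
    compute_embeddings_sentence_transformers([], model_name="facebook/contriever")
except ValueError as err:
    print(err)  # -> Cannot compute embeddings for empty text list
```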
@@ -64,7 +64,7 @@ class LeannBackendSearcherInterface(ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters

         Returns:
@@ -104,6 +104,10 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
         # Try to use embedding server if available and requested
         if use_server_if_available:
             try:
+                # TODO: Maybe we can directly use this port here?
+                # For this internal method, it's ok to assume that the server is running
+                # on that port?
+
                 # Ensure we have a server with passages_file for compatibility
                 passages_source_file = (
                     self.index_dir / f"{self.index_path.name}.meta.json"
@@ -181,7 +185,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
             prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
             recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored PQ codes
             pruning_strategy: PQ candidate selection strategy - "global" (default), "local", or "proportional"
-            zmq_port: ZMQ port for embedding server communication
+            zmq_port: ZMQ port for embedding server communication. Must be provided if recompute_embeddings is True.
             **kwargs: Backend-specific parameters (e.g., batch_size, dedup_node_dis, etc.)

         Returns: