Compare commits

..

6 Commits

Author     SHA1        Message                                                                     Date
Andy Lee   16705fc44a  refactor: passage structure                                                 2025-07-06 21:48:38 +00:00
Andy Lee   5611f708e9  docs: embedding pruning                                                     2025-07-06 19:50:01 +00:00
Andy Lee   b4ae57b2c0  feat: auto discovery of packages and fix passage gen for diskann           2025-07-06 05:05:49 +00:00
Andy Lee   5659174635  fix: diskann zmq port and passages                                          2025-07-06 04:14:15 +00:00
Andy Lee   a38bc0a3fc  refactor: embedding server manager                                          2025-07-06 01:54:46 +00:00
yichuan    449983c937  Merge pull request #1 from yichuan520030910320/debug_diskann_disable_pipe  2025-07-05 17:55:27 -07:00
17 changed files with 832 additions and 10047 deletions

.gitignore (vendored, +3 lines)
View File

@@ -71,3 +71,6 @@ test_indices*/
test_*.py test_*.py
!tests/** !tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/ packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json

View File

@@ -74,7 +74,7 @@ def main():
print(f"⏱️ Basic search time: {basic_time:.3f} seconds") print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
print(">>> Basic search results <<<") print(">>> Basic search results <<<")
for i, res in enumerate(results, 1): for i, res in enumerate(results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}") print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# --- 3. Recompute search demo --- # --- 3. Recompute search demo ---
print(f"\n[PHASE 3] Recompute search using embedding server...") print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds") print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
print(">>> Recompute search results <<<") print(">>> Recompute search results <<<")
for i, res in enumerate(recompute_results, 1): for i, res in enumerate(recompute_results, 1):
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}") print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
# Compare results # Compare results
print(f"\n--- Result comparison ---") print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():
print("\nBasic search vs Recompute results:") print("\nBasic search vs Recompute results:")
for i in range(min(len(results), len(recompute_results))): for i in range(min(len(results), len(recompute_results))):
basic_score = results[i]['score'] basic_score = results[i].score
recompute_score = recompute_results[i]['score'] recompute_score = recompute_results[i].score
score_diff = abs(basic_score - recompute_score) score_diff = abs(basic_score - recompute_score)
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}") print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

View File

@@ -10,7 +10,6 @@ import asyncio
import os import os
import dotenv import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_hnsw # Import to ensure backend registration
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -39,7 +38,7 @@ all_texts = []
for doc in documents: for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes: for node in nodes:
all_texts.append(node.text) all_texts.append(node.get_content())
INDEX_DIR = Path("./test_pdf_index") INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann") INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
# CSR compact mode with recompute # CSR compact mode with recompute
builder = LeannBuilder( builder = LeannBuilder(
backend_name="hnsw", backend_name="diskann",
embedding_model="facebook/contriever", embedding_model="facebook/contriever",
graph_degree=32, graph_degree=32,
complexity=64, complexity=64,
@@ -74,7 +73,7 @@ async def main():
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?" query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}") print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever") chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}") print(f"Leann: {chat_response}")
if __name__ == "__main__": if __name__ == "__main__":
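This demo drops the explicit backend import (backends are now auto-discovered), switches the builder from the hnsw backend to diskann, and stops passing embedding_model to chat.ask, which now comes from the index metadata. A hedged sketch of the updated flow; the PDF loading and chunking are elided, the paths are placeholders, and the LeannChat constructor signature is assumed from the demo:

    from leann.api import LeannBuilder, LeannChat  # backends self-register on import

    builder = LeannBuilder(
        backend_name="diskann",                # was "hnsw" before this change
        embedding_model="facebook/contriever",
        graph_degree=32,
        complexity=64,
    )
    for text in ["chunk one", "chunk two"]:    # stand-ins for the parsed PDF nodes
        builder.add_text(text)
    builder.build_index("./test_pdf_index/pdf_documents.leann")

    chat = LeannChat("./test_pdf_index/pdf_documents.leann")
    # embedding_model is no longer repeated here; the keyword spelling follows the code in this diff
    response = chat.ask("What does the paper propose?", top_k=20,
                        recompute_beighbor_embeddings=True)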

View File

@@ -0,0 +1 @@
from . import diskann_backend

View File

@@ -3,7 +3,7 @@ import os
import json import json
import struct import struct
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, Any, List
import contextlib import contextlib
import threading import threading
import time import time
@@ -11,20 +11,22 @@ import atexit
import socket import socket
import subprocess import subprocess
import sys import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from leann.registry import register_backend from leann.registry import register_backend
from leann.interface import ( from leann.interface import (
LeannBackendFactoryInterface, LeannBackendFactoryInterface,
LeannBackendBuilderInterface, LeannBackendBuilderInterface,
LeannBackendSearcherInterface LeannBackendSearcherInterface
) )
from . import _diskannpy as diskannpy def _get_diskann_metrics():
from . import _diskannpy as diskannpy
METRIC_MAP = { return {
"mips": diskannpy.Metric.INNER_PRODUCT, "mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2, "l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE, "cosine": diskannpy.Metric.COSINE,
} }
@contextlib.contextmanager @contextlib.contextmanager
def chdir(path): def chdir(path):
@@ -35,103 +37,13 @@ def chdir(path):
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str): def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
num_vectors, dim = data.shape num_vectors, dim = data.shape
with open(file_path, 'wb') as f: with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors)) f.write(struct.pack('I', num_vectors))
f.write(struct.pack('I', dim)) f.write(struct.pack('I', dim))
f.write(data.tobytes()) f.write(data.tobytes())
def _check_port(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class EmbeddingServerManager:
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
return True
# Check whether the port is already occupied by another, unrelated process
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
return True
print(f"INFO: Starting session-level embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
]
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _log_monitor(self):
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[EmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[EmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("diskann") @register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface): class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -143,23 +55,21 @@ class DiskannBackend(LeannBackendFactoryInterface):
path = Path(index_path) path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json" meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists(): if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.") raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f: with open(meta_path, 'r') as f:
meta = json.load(f) meta = json.load(f)
dimensions = meta.get("dimensions") # Pass essential metadata to the searcher
if not dimensions: kwargs['meta'] = meta
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
kwargs['dimensions'] = dimensions
return DiskannSearcher(index_path, **kwargs) return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface): class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.build_params = kwargs self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
path = Path(index_path) path = Path(index_path)
index_dir = path.parent index_dir = path.parent
index_prefix = path.stem index_prefix = path.stem
@@ -174,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
data_filename = f"{index_prefix}_data.bin" data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename) _write_vectors_to_bin(data, index_dir / data_filename)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
build_kwargs = {**self.build_params, **kwargs} build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower() metric_str = build_kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(metric_str) metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
@@ -191,6 +108,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...") print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try: try:
from . import _diskannpy as diskannpy
with chdir(index_dir): with chdir(index_dir):
diskannpy.build_disk_float_index( diskannpy.build_disk_float_index(
metric_enum, metric_enum,
@@ -215,33 +133,53 @@ class DiskannBuilder(LeannBackendBuilderInterface):
class DiskannSearcher(LeannBackendSearcherInterface): class DiskannSearcher(LeannBackendSearcherInterface):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
self.meta = kwargs.get("meta", {})
if not self.meta:
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
self.embedding_model = self.meta.get("embedding_model")
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path) path = Path(index_path)
index_dir = path.parent self.index_dir = path.parent
index_prefix = path.stem self.index_prefix = path.stem
metric_str = kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str) # Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
# Extract parameters for DiskANN
distance_metric = kwargs.get("distance_metric", "mips").lower()
METRIC_MAP = _get_diskann_metrics()
metric_enum = METRIC_MAP.get(distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
num_threads = kwargs.get("num_threads", 8) num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0) num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
dimensions = kwargs.get("dimensions") self.zmq_port = kwargs.get("zmq_port", 6666)
if not dimensions:
raise ValueError("Vector dimension not provided to DiskannSearcher.")
try: try:
full_index_prefix = str(index_dir / index_prefix) from . import _diskannpy as diskannpy
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex( self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", "" metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
) )
self.num_threads = num_threads self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager() self.embedding_server_manager = EmbeddingServerManager(
backend_module_name="leann_backend_diskann.embedding_server"
)
print("✅ DiskANN index loaded successfully.") print("✅ DiskANN index loaded successfully.")
except Exception as e: except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}") print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256) complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4) beam_width = kwargs.get("beam_width", 4)
@@ -252,15 +190,32 @@ class DiskannSearcher(LeannBackendSearcherInterface):
prune_ratio = kwargs.get("prune_ratio", 0.0) prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False) batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False) global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings: if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running") print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
zmq_port = kwargs.get("zmq_port", 6666) if not self.embedding_model:
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2") raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model): passages_file = kwargs.get("passages_file")
print(f"WARNING: Failed to start embedding server, falling back to PQ computation") if not passages_file:
kwargs['recompute_beighbor_embeddings'] = False # Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=kwargs.get("distance_metric", "mips"),
passages_file=passages_file
)
if not server_started:
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
@@ -283,11 +238,23 @@ class DiskannSearcher(LeannBackendSearcherInterface):
batch_recompute, batch_recompute,
global_pruning global_pruning
) )
return {"labels": labels, "distances": distances}
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e: except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}") print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0] batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64), return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)} "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self): def __del__(self):
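At search time the DiskANN searcher now drives recompute entirely from meta.json: the embedding model and the passages file come from the metadata, the embedding server is started (or reused) on the searcher's ZMQ port, and integer labels are mapped back to string ids before returning. A condensed, assumption-laden sketch of just the metadata-resolution step:

    import json
    from pathlib import Path

    def resolve_recompute_inputs(index_path: str):
        """Return (embedding_model, passages_file) read from the index's meta.json."""
        path = Path(index_path)
        meta = json.loads((path.parent / f"{path.name}.meta.json").read_text())
        sources = meta.get("passage_sources") or []
        if not sources:
            raise RuntimeError("Recompute enabled but no passage_sources found in metadata.")
        return meta["embedding_model"], sources[0]["path"]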

View File

@@ -7,6 +7,8 @@ import pickle
import argparse import argparse
import threading import threading
import time import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel from transformers import AutoTokenizer, AutoModel
import os import os
@@ -17,49 +19,76 @@ import numpy as np
RED = "\033[91m" RED = "\033[91m"
RESET = "\033[0m" RESET = "\033[0m"
# Simplified document store - replaces LazyPassages # --- New Passage Loader from HNSW backend ---
class SimpleDocumentStore: class SimplePassageLoader:
"""简化的文档存储支持任意ID""" """
def __init__(self, documents: dict = None): Simple passage loader that replaces config.py dependencies
self.documents = documents or {} """
# Default demo documents def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.default_docs = { self.passages_data = passages_data or {}
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
def __getitem__(self, doc_id): def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
doc_id = int(doc_id) """Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
# Prefer the explicitly supplied documents def __len__(self) -> int:
if doc_id in self.documents: return len(self.passages_data)
return {"text": self.documents[doc_id]}
# Otherwise fall back to the default demo documents def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
if doc_id in self.default_docs: """
return {"text": self.default_docs[doc_id]} Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
from pathlib import Path
import pickle
# For any other ID, return a generic document if not os.path.exists(passages_file):
fallback_docs = [ raise FileNotFoundError(f"Passages file {passages_file} not found.")
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# Pick a fallback document based on the ID if not passages_file.endswith('.jsonl'):
fallback_text = fallback_docs[doc_id % len(fallback_docs)] raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __len__(self): # Load label map (int -> string_id)
return len(self.documents) + len(self.default_docs) passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_embedding_server_thread( def create_embedding_server_thread(
zmq_port=5555, zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2", model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128, max_batch_size=128,
passages_file: Optional[str] = None,
): ):
""" """
Create and run the embedding server in the current thread Create and run the embedding server in the current thread
@@ -109,15 +138,14 @@ def create_embedding_server_thread(
except Exception as e: except Exception as e:
print(f"WARNING: Model optimization failed: {e}") print(f"WARNING: Model optimization failed: {e}")
# Default demo documents # Load passages from file if provided
demo_documents = { if passages_file and os.path.exists(passages_file):
0: "Python is a high-level, interpreted language known for simplicity.", passages = load_passages_from_file(passages_file)
1: "Machine learning builds systems that learn from data.", else:
2: "Data structures like arrays, lists, and graphs organize data.", print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
} passages = SimplePassageLoader()
passages = SimpleDocumentStore(demo_documents) print(f"INFO: Loaded {len(passages)} passages.")
print(f"INFO: Loaded {len(passages)} demo documents")
class DeviceTimer: class DeviceTimer:
"""设备计时器""" """设备计时器"""
@@ -264,7 +292,13 @@ def create_embedding_server_thread(
for nid in node_ids: for nid in node_ids:
txtinfo = passages[nid] txtinfo = passages[nid]
txt = txtinfo["text"] txt = txtinfo["text"]
if txt:
texts.append(txt) texts.append(txt)
else:
# If the text is empty, we still need a placeholder for batching,
# but record its ID as missing
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed() lookup_timer.print_elapsed()
if missing_ids: if missing_ids:
@@ -360,18 +394,20 @@ def create_embedding_server(
max_batch_size=128, max_batch_size=128,
lazy_load_passages=False, lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2", model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
): ):
""" """
The original create_embedding_server function is kept unchanged The original create_embedding_server function is kept unchanged
This is the blocking version, used for running directly This is the blocking version, used for running directly
""" """
create_embedding_server_thread(zmq_port, model_name, max_batch_size) create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service") parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on") parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name") parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True) parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False) parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False) parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +430,5 @@ if __name__ == "__main__":
max_batch_size=args.max_batch_size, max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages, lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name, model_name=args.model_name,
passages_file=args.passages_file,
) )
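The embedding server no longer ships hard-coded demo documents; it loads real passages from a JSONL file and uses the pickled label map to translate DiskANN's integer node ids into passage text. A self-contained sketch of that loading logic; file names follow the diff, the data layout is the JSONL format described above:

    import json
    import pickle
    from pathlib import Path

    def load_int_id_passages(passages_file: str) -> dict:
        """Map stringified integer labels to passage text, mirroring load_passages_from_file."""
        passages_dir = Path(passages_file).parent
        with open(passages_dir / "leann.labels.map", "rb") as f:
            label_map = pickle.load(f)                 # int label -> string passage id

        by_string_id = {}
        with open(passages_file, "r", encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    rec = json.loads(line)             # {"id": ..., "text": ..., "metadata": ...}
                    by_string_id[rec["id"]] = rec["text"]

        return {str(i): by_string_id[sid] for i, sid in label_map.items() if sid in by_string_id}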

View File

@@ -468,16 +468,27 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
# --- Write CSR HNSW graph data using unified function --- # --- Write CSR HNSW graph data using unified function ---
print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...") print(f"[{time.time() - start_time:.2f}s] Writing CSR HNSW graph data in FAISS-compatible order...")
# Determine storage fourcc based on prune_embeddings # Determine storage fourcc and data based on prune_embeddings
output_storage_fourcc = NULL_INDEX_FOURCC if prune_embeddings else (storage_fourcc if 'storage_fourcc' in locals() else NULL_INDEX_FOURCC)
if prune_embeddings: if prune_embeddings:
print(f" Pruning embeddings: Writing NULL storage marker.") print(f" Pruning embeddings: Writing NULL storage marker.")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b''
else:
# Keep embeddings - read and preserve original storage data
if storage_fourcc and storage_fourcc != NULL_INDEX_FOURCC:
print(f" Preserving embeddings: Reading original storage data...")
storage_data = f_in.read() # Read remaining storage data
output_storage_fourcc = storage_fourcc
print(f" Read {len(storage_data)} bytes of storage data")
else:
print(f" No embeddings found in original file (NULL storage)")
output_storage_fourcc = NULL_INDEX_FOURCC
storage_data = b'' storage_data = b''
# Use the unified write function # Use the unified write function
write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np, write_compact_format(f_out, original_hnsw_data, assign_probas_np, cum_nneighbor_per_level_np,
levels_np, compact_level_ptr, compact_node_offsets_np, levels_np, compact_level_ptr, compact_node_offsets_np,
compact_neighbors_data, output_storage_fourcc, storage_data if not prune_embeddings else b'') compact_neighbors_data, output_storage_fourcc, storage_data)
# Clean up memory # Clean up memory
del assign_probas_np, cum_nneighbor_per_level_np, levels_np del assign_probas_np, cum_nneighbor_per_level_np, levels_np
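The converter previously always wrote an empty storage section; it now writes the NULL marker only when prune_embeddings is set and otherwise copies the original vector storage through unchanged. The decision reduces to roughly the following helper (a sketch; the fourcc values and the remaining bytes are supplied by the converter):

    def choose_storage(prune_embeddings, storage_fourcc, remaining_bytes, null_fourcc):
        """Pick the storage fourcc and payload to write into the CSR file."""
        if prune_embeddings:
            return null_fourcc, b""                    # embeddings pruned: NULL storage marker
        if storage_fourcc and storage_fourcc != null_fourcc:
            return storage_fourcc, remaining_bytes     # preserve the original embeddings
        return null_fourcc, b""                        # source file had no embeddings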

View File

@@ -3,7 +3,7 @@ import os
import json import json
import struct import struct
from pathlib import Path from pathlib import Path
from typing import Dict, Any from typing import Dict, Any, List
import contextlib import contextlib
import threading import threading
import time import time
@@ -11,7 +11,9 @@ import atexit
import socket import socket
import subprocess import subprocess
import sys import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from .convert_to_csr import convert_hnsw_graph_to_csr from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend from leann.registry import register_backend
@@ -29,118 +31,6 @@ def get_metric_map():
"cosine": faiss.METRIC_INNER_PRODUCT, "cosine": faiss.METRIC_INNER_PRODUCT,
} }
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class HNSWEmbeddingServerManager:
"""
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process.
Mirrors the DiskANN EmbeddingServerManager architecture.
"""
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"):
"""
Start the HNSW embedding server process.
Args:
port: ZMQ port for the server
model_name: Name of the embedding model to use
passages_file: Optional path to passages JSON file
distance_metric: The distance metric to use
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
return True
# Check if port is already in use
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
return True
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "leann_backend_hnsw.hnsw_embedding_server",
"--zmq-port", str(port),
"--model-name", model_name,
"--distance-metric", distance_metric
]
if passages_file:
command.extend(["--passages-file", str(passages_file)])
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running HNSW command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ HNSW embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
return False
def _log_monitor(self):
"""Monitor server logs"""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"HNSW Log monitor error: {e}")
def stop_server(self):
"""Stop the HNSW embedding server process"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: HNSW server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("hnsw") @register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface): class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod @staticmethod
@@ -152,16 +42,12 @@ class HNSWBackend(LeannBackendFactoryInterface):
path = Path(index_path) path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json" meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists(): if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.") raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
with open(meta_path, 'r') as f: with open(meta_path, 'r') as f:
meta = json.load(f) meta = json.load(f)
dimensions = meta.get("dimensions") kwargs['meta'] = meta
if not dimensions:
raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.")
kwargs['dimensions'] = dimensions
return HNSWSearcher(index_path, **kwargs) return HNSWSearcher(index_path, **kwargs)
class HNSWBuilder(LeannBackendBuilderInterface): class HNSWBuilder(LeannBackendBuilderInterface):
@@ -189,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_recompute and not self.is_compact: if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency") raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs): def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS""" """Build HNSW index using FAISS"""
from . import faiss from . import faiss
@@ -204,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if not data.flags['C_CONTIGUOUS']: if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data) data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower() metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str) metric_enum = get_metric_map().get(metric_str)
if metric_enum is None: if metric_enum is None:
@@ -234,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_compact: if self.is_compact:
self._convert_to_csr(index_file) self._convert_to_csr(index_file)
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e: except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}") print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise raise
@@ -270,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
print(f"💥 ERROR: CSR conversion failed. Exception: {e}") print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise raise
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
class HNSWSearcher(LeannBackendSearcherInterface): class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]: def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
@@ -376,47 +241,58 @@ class HNSWSearcher(LeannBackendSearcherInterface):
def __init__(self, index_path: str, **kwargs): def __init__(self, index_path: str, **kwargs):
from . import faiss from . import faiss
path = Path(index_path) self.meta = kwargs.get("meta", {})
index_dir = path.parent if not self.meta:
index_prefix = path.stem raise ValueError("HNSWSearcher requires metadata from .meta.json.")
# Store configuration and paths for later use self.dimensions = self.meta.get("dimensions")
self.config = kwargs.copy() if not self.dimensions:
self.config["index_path"] = index_path raise ValueError("Dimensions not found in Leann metadata.")
self.index_dir = index_dir
self.index_prefix = index_prefix
metric_str = self.config.get("distance_metric", "mips").lower() self.distance_metric = self.meta.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str) metric_enum = get_metric_map().get(self.distance_metric)
if metric_enum is None: if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.") raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
dimensions = self.config.get("dimensions") self.embedding_model = self.meta.get("embedding_model")
if not dimensions: if not self.embedding_model:
raise ValueError("Vector dimension not provided to HNSWSearcher.") print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
index_file = index_dir / f"{index_prefix}.index" path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index"
if not index_file.exists(): if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}") raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file) self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints # Validate configuration constraints
if not self.is_compact and self.config.get("is_skip_neighbors", False): if not self.is_compact and kwargs.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True") raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"): if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously") raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig() hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation # Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False) hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False) hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0) hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = self.config.get("external_storage_path") hnsw_config.external_storage_path = kwargs.get("external_storage_path")
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute: if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.") raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -431,85 +307,72 @@ class HNSWSearcher(LeannBackendSearcherInterface):
else: else:
print("✅ Standard HNSW index loaded successfully.") print("✅ Standard HNSW index loaded successfully.")
self.metric_str = metric_str self.embedding_server_manager = EmbeddingServerManager(
self.embedding_server_manager = HNSWEmbeddingServerManager() backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
)
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
"""Get the appropriate index file path based on format"""
# We always use the same filename now, format is detected internally
return index_dir / f"{index_prefix}.index"
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality""" """Search using HNSW index with optional recompute functionality"""
from . import faiss from . import faiss
# Merge config with search-time kwargs
search_config = self.config.copy()
search_config.update(kwargs)
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search ef = kwargs.get("ef", 200)
# Recompute parameters
zmq_port = search_config.get("zmq_port", 5557)
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = search_config.get("passages_file", None)
# For recompute mode, try to find the passages file automatically
if self.is_pruned and not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Found passages file for recompute mode: {passages_file}")
else:
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
if self.is_pruned: if self.is_pruned:
print(f"INFO: Index is pruned - starting embedding server for recompute") print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
if not self.embedding_model:
raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
# CRITICAL: Check passages file exists - fail fast if not passages_file = kwargs.get("passages_file")
if not passages_file: if not passages_file:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.") # Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
# Check if server is already running first passage_source = self.meta['passage_sources'][0]
if _check_port(zmq_port): passages_file = passage_source['path']
print(f"INFO: Embedding server already running on port {zmq_port}") print(f"INFO: Found passages file from metadata: {passages_file}")
else: else:
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str): raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
port=zmq_port,
model_name=self.embedding_model,
passages_file=passages_file,
distance_metric=self.distance_metric
)
if not server_started:
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}") raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
# Give server extra time to fully initialize
print(f"INFO: Waiting for embedding server to fully initialize...")
time.sleep(3)
# Final verification
if not _check_port(zmq_port):
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
else:
print(f"INFO: Index has embeddings stored - no recompute needed")
if query.dtype != np.float32: if query.dtype != np.float32:
query = query.astype(np.float32) query = query.astype(np.float32)
if query.ndim == 1: if query.ndim == 1:
query = np.expand_dims(query, axis=0) query = np.expand_dims(query, axis=0)
# Normalize query if using cosine similarity if self.distance_metric == "cosine":
if self.metric_str == "cosine":
faiss.normalize_L2(query) faiss.normalize_L2(query)
try: try:
# Set search parameter params = faiss.SearchParametersHNSW()
self._index.hnsw.efSearch = ef params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
# Prepare output arrays for the older FAISS SWIG API
batch_size = query.shape[0] batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32) distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64) labels = np.empty((batch_size, top_k), dtype=np.int64)
# Use standard FAISS search - recompute is handled internally by FAISS self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
return {"labels": labels, "distances": distances} # Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e: except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}") print(f"💥 ERROR: HNSW search failed. Exception: {e}")

View File

@@ -58,21 +58,46 @@ class SimplePassageLoader:
def load_passages_from_file(passages_file: str) -> SimplePassageLoader: def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
""" """
Load passages from a JSON file Load passages from a JSONL file with label map support
Expected format: {"passage_id": "passage_text", ...} Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
""" """
if not os.path.exists(passages_file): if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.") raise FileNotFoundError(f"Passages file {passages_file} not found.")
return SimplePassageLoader()
try: if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f: with open(passages_file, 'r', encoding='utf-8') as f:
passages_data = json.load(f) for line in f:
print(f"Loaded {len(passages_data)} passages from {passages_file}") if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data) return SimplePassageLoader(passages_data)
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader()
def create_hnsw_embedding_server( def create_hnsw_embedding_server(
passages_file: Optional[str] = None, passages_file: Optional[str] = None,

View File

@@ -1,17 +1,7 @@
-# This file makes the 'leann' directory a Python package.
-from .api import LeannBuilder, LeannSearcher, LeannChat, SearchResult
-# Import backends to ensure they are registered
-try:
-    import leann_backend_hnsw
-except ImportError:
-    pass
-try:
-    import leann_backend_diskann
-except ImportError:
-    pass
-__all__ = ['LeannBuilder', 'LeannSearcher', 'LeannChat', 'SearchResult']
+# packages/leann-core/src/leann/__init__.py
+from .api import LeannBuilder, LeannChat, LeannSearcher
+from .registry import BACKEND_REGISTRY, autodiscover_backends
+autodiscover_backends()
+__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
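The manual try/except backend imports are replaced by autodiscover_backends() from leann.registry. The registry implementation itself is not part of this diff, so the following is only a guess at its general shape: import every installed leann_backend_* package so its @register_backend decorators run. Treat the names and logic as hypothetical:

    # Hypothetical sketch; the real leann.registry is not shown in this diff.
    import importlib
    import pkgutil

    BACKEND_REGISTRY = {}

    def autodiscover_backends():
        """Import every installed leann_backend_* package so it can self-register."""
        for mod in pkgutil.iter_modules():
            if mod.name.startswith("leann_backend_"):
                try:
                    importlib.import_module(mod.name)
                except ImportError as exc:
                    print(f"WARNING: could not load backend '{mod.name}': {exc}")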

View File

@@ -7,6 +7,8 @@ import json
from pathlib import Path from pathlib import Path
import openai import openai
from dataclasses import dataclass, field from dataclasses import dataclass, field
import uuid
import pickle
# --- Helper Functions for Embeddings --- # --- Helper Functions for Embeddings ---
@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
@dataclass @dataclass
class SearchResult: class SearchResult:
"""Represents a single search result.""" """Represents a single search result."""
id: int id: str
score: float score: float
text: str text: str
metadata: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
"""Manages passage data and lazy loading from JSONL files."""
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not os.path.exists(index_file):
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
def get_passage(self, passage_id: str) -> Dict[str, Any]:
"""Lazy load a passage by ID."""
for passage_file, offset_map in self.offset_maps.items():
if passage_id in offset_map:
offset = offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
line = f.readline()
return json.loads(line)
raise KeyError(f"Passage ID not found: {passage_id}")
# --- Core Classes --- # --- Core Classes ---
class LeannBuilder: class LeannBuilder:
@@ -82,7 +118,26 @@ class LeannBuilder:
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.") print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None): def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
self.chunks.append({"text": text, "metadata": metadata or {}}) if metadata is None:
metadata = {}
# Check if ID is provided in metadata
passage_id = metadata.get('id')
if passage_id is None:
passage_id = str(uuid.uuid4())
else:
# Validate uniqueness
existing_ids = {chunk['id'] for chunk in self.chunks}
if passage_id in existing_ids:
raise ValueError(f"Duplicate passage ID: {passage_id}")
# Store the definitive ID with the chunk
chunk_data = {
"id": passage_id,
"text": text,
"metadata": metadata
}
self.chunks.append(chunk_data)
def build_index(self, index_path: str): def build_index(self, index_path: str):
if not self.chunks: if not self.chunks:
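add_text now gives every chunk a stable string id: either the one supplied through metadata['id'] (checked for uniqueness) or a freshly generated UUID. In use this looks roughly as follows; the builder arguments mirror the examples earlier in this diff:

    from leann.api import LeannBuilder

    builder = LeannBuilder(backend_name="hnsw", embedding_model="facebook/contriever")
    builder.add_text("Passage with an explicit id", metadata={"id": "doc-001"})
    builder.add_text("Passage that receives a generated UUID id")
    # builder.add_text("Duplicate", metadata={"id": "doc-001"})  # would raise ValueError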
@@ -92,28 +147,65 @@ class LeannBuilder:
self.dimensions = _get_embedding_dimensions(self.embedding_model) self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}") print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
# Ensure the directory exists
index_dir.mkdir(parents=True, exist_ok=True)
# Create the passages.jsonl file and offset index
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
passage_data = {
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"]
}
json.dump(passage_data, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
# Save the offset map
with open(offset_file, 'wb') as f:
pickle.dump(offset_map, f)
# Compute embeddings
texts_to_embed = [c["text"] for c in self.chunks] texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model) embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
# Extract string IDs for the backend
string_ids = [chunk["id"] for chunk in self.chunks]
# Build the vector index
current_backend_kwargs = self.backend_kwargs.copy() current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions current_backend_kwargs['dimensions'] = self.dimensions
builder_instance = self.backend_factory.builder(**current_backend_kwargs) builder_instance = self.backend_factory.builder(**current_backend_kwargs)
build_kwargs = current_backend_kwargs.copy() builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
build_kwargs['chunks'] = self.chunks
builder_instance.build(embeddings, index_path, **build_kwargs)
index_dir = Path(index_path).parent # Create the lightweight meta.json file
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json" leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = { meta_data = {
"version": "0.1.0", "version": "1.0",
"backend_name": self.backend_name, "backend_name": self.backend_name,
"embedding_model": self.embedding_model, "embedding_model": self.embedding_model,
"dimensions": self.dimensions, "dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs, "backend_kwargs": self.backend_kwargs,
"num_chunks": len(self.chunks), "passage_sources": [
"chunks": self.chunks, {
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file)
}
]
} }
with open(leann_meta_path, 'w', encoding='utf-8') as f: with open(leann_meta_path, 'w', encoding='utf-8') as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
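build_index now writes passages exactly once, as a JSONL file plus a pickled offset index, so they can later be fetched lazily by byte offset instead of living inside meta.json. The core trick is recording f.tell() before each record and seeking back to it later; a standalone sketch with made-up file names:

    import json
    import pickle

    passages = [{"id": "doc-a", "text": "alpha"}, {"id": "doc-b", "text": "beta"}]

    # Write JSONL and remember each record's byte offset.
    offsets = {}
    with open("demo.passages.jsonl", "w", encoding="utf-8") as f:
        for p in passages:
            offsets[p["id"]] = f.tell()
            json.dump(p, f, ensure_ascii=False)
            f.write("\n")
    with open("demo.passages.idx", "wb") as f:
        pickle.dump(offsets, f)

    # Later: lazy lookup without reading the whole file.
    with open("demo.passages.idx", "rb") as f:
        offsets = pickle.load(f)
    with open("demo.passages.jsonl", "r", encoding="utf-8") as f:
        f.seek(offsets["doc-b"])
        print(json.loads(f.readline())["text"])   # -> "beta"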
@@ -136,14 +228,16 @@ class LeannSearcher:
backend_name = self.meta_data['backend_name'] backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model'] self.embedding_model = self.meta_data['embedding_model']
# Initialize the passage manager
passage_sources = self.meta_data.get('passage_sources', [])
self.passage_manager = PassageManager(passage_sources)
backend_factory = BACKEND_REGISTRY.get(backend_name) backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None: if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.") raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = self.meta_data.get("backend_kwargs", {}) final_kwargs = backend_kwargs.copy()
final_kwargs.update(backend_kwargs) final_kwargs['meta'] = self.meta_data
if 'dimensions' not in final_kwargs:
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs) self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.") print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
@@ -155,15 +249,17 @@ class LeannSearcher:
        results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
        enriched_results = []
-       for label, dist in zip(results['labels'][0], results['distances'][0]):
-           if label < len(self.meta_data['chunks']):
-               chunk_info = self.meta_data['chunks'][label]
-               enriched_results.append(SearchResult(
-                   id=label,
-                   score=dist,
-                   text=chunk_info['text'],
-                   metadata=chunk_info.get('metadata', {})
-               ))
+       for string_id, dist in zip(results['labels'][0], results['distances'][0]):
+           try:
+               passage_data = self.passage_manager.get_passage(string_id)
+               enriched_results.append(SearchResult(
+                   id=string_id,
+                   score=dist,
+                   text=passage_data['text'],
+                   metadata=passage_data.get('metadata', {})
+               ))
+           except KeyError:
+               print(f"WARNING: Passage ID '{string_id}' not found in passage files")
        return enriched_results
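
Taken together with the builder hunk above, results are now resolved by string ID from the passage store rather than from a `chunks` list in meta.json, and callers read them as `SearchResult` attributes. A rough end-to-end sketch of the public API after this change (paths are illustrative; `add_text`/`build_index` follow the sanity-check script later in this diff, and the exact `search()` signature is inferred from the hunk above):

```python
from leann.api import LeannBuilder, LeannSearcher

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
)
builder.add_text("LEANN stores passages in a JSONL file next to the index.")
builder.add_text("meta.json only records where those passages live.")
builder.build_index("indexes/demo.hnsw")   # writes index + passages + offsets + meta.json

searcher = LeannSearcher("indexes/demo.hnsw")
for res in searcher.search("Where are passages stored?", top_k=2):
    # SearchResult fields per the hunk above: string id, score, text, metadata
    print(res.id, f"{res.score:.4f}", res.text)
```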

View File

@@ -0,0 +1,132 @@
import os
import threading
import time
import atexit
import socket
import subprocess
import sys
from pathlib import Path
from typing import Optional
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class EmbeddingServerManager:
"""
A generic manager for handling the lifecycle of a backend-specific embedding server process.
"""
def __init__(self, backend_module_name: str):
"""
Initializes the manager for a specific backend.
Args:
backend_module_name (str): The full module name of the backend's server script.
e.g., "leann_backend_diskann.embedding_server"
"""
self.backend_module_name = backend_module_name
self.server_process: Optional[subprocess.Popen] = None
self.server_port: Optional[int] = None
atexit.register(self.stop_server)
def start_server(self, port: int, model_name: str, **kwargs) -> bool:
"""
Starts the embedding server process.
Args:
port (int): The ZMQ port for the server.
model_name (str): The name of the embedding model to use.
**kwargs: Additional arguments for the server (e.g., passages_file, distance_metric).
Returns:
bool: True if the server is started successfully or already running, False otherwise.
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
return True
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
return True
print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")
try:
command = [
sys.executable,
"-m", self.backend_module_name,
"--zmq-port", str(port),
"--model-name", model_name
]
# Add extra arguments for specific backends
if "passages_file" in kwargs and kwargs["passages_file"]:
command.extend(["--passages-file", str(kwargs["passages_file"])])
# if "distance_metric" in kwargs and kwargs["distance_metric"]:
# command.extend(["--distance-metric", kwargs["distance_metric"]])
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
"""Stops the embedding server process if it's running."""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
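
A sketch of how a backend might drive this manager around a search session. The import path, port, and file paths are assumptions for illustration; this diff shows the class itself but not where it lives or how each backend wires it in (the module name passed in below is the one used as an example in the docstring):

```python
from leann.embedding_server_manager import EmbeddingServerManager  # hypothetical import path

manager = EmbeddingServerManager("leann_backend_diskann.embedding_server")
ok = manager.start_server(
    port=5555,                                            # ZMQ port (illustrative)
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    passages_file="indexes/demo.passages.json",           # forwarded as --passages-file
)
if not ok:
    raise RuntimeError("embedding server did not come up")

# ... send ZMQ recompute requests to tcp://localhost:5555 during search ...

manager.stop_server()  # also registered via atexit, so this is a safety net
```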

View File

@@ -1,6 +1,9 @@
# packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING
import importlib
import importlib.metadata
if TYPE_CHECKING:
    from leann.interface import LeannBackendFactoryInterface
@@ -13,3 +16,21 @@ def register_backend(name: str):
        BACKEND_REGISTRY[name] = cls
        return cls
    return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")

View File

File diff suppressed because it is too large

View File

@@ -0,0 +1,68 @@
# HNSW Index Storage Optimization
This document explains the storage optimization features available in the HNSW backend.
## Storage Modes
The HNSW backend supports two orthogonal optimization techniques:
### 1. CSR Compression (`is_compact=True`)
- Converts the graph structure from standard format to Compressed Sparse Row (CSR) format (see the sketch after this list)
- Reduces memory overhead from graph adjacency storage
- Maintains all embedding data for direct access
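
For readers who have not seen CSR before: instead of giving every node a fixed-size neighbor block (padded up to the maximum degree), the adjacency lists are packed into one flat neighbor array plus an offset array. The snippet below is a schematic of that layout only, not LEANN's exact on-disk format:

```python
import numpy as np

# Adjacency lists of a 4-node graph with varying degree
neighbors = [[1, 2], [0, 2, 3], [0, 1], [1]]

# CSR packs them into one flat array plus per-node offsets
indices = np.array([n for row in neighbors for n in row])   # [1 2 0 2 3 0 1 1]
indptr = np.cumsum([0] + [len(row) for row in neighbors])   # [0 2 5 7 8]

# Node i's neighbors are indices[indptr[i]:indptr[i + 1]]
assert list(indices[indptr[2]:indptr[3]]) == [0, 1]
```

Because no node pays for unused neighbor slots, only the graph portion of the index shrinks, which is why the reduction from CSR alone is modest compared to pruning (see the numbers below).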
### 2. Embedding Pruning (`is_recompute=True`)
- Removes embedding vectors from the index file
- Replaces them with a NULL storage marker
- Requires recomputation via embedding server during search
- Must be used with `is_compact=True` for efficiency
## Performance Impact
**Storage Reduction (100 vectors, 384 dimensions):**
```
Standard format: 168 KB (embeddings + graph)
CSR only: 160 KB (embeddings + compressed graph)
CSR + Pruned: 6 KB (compressed graph only)
```
**Key Benefits:**
- **CSR compression**: ~5% size reduction from graph optimization
- **Embedding pruning**: ~95% size reduction by removing embeddings
- **Combined**: Up to 96% total storage reduction
## Usage
```python
# Standard format (largest)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=False,
is_recompute=False
)
# CSR compressed (medium)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True,
is_recompute=False
)
# CSR + Pruned (smallest, requires embedding server)
builder = LeannBuilder(
backend_name="hnsw",
is_compact=True, # Required for pruning
is_recompute=True # Default: enabled
)
```
## Trade-offs
| Mode | Storage | Search Speed | Memory Usage | Setup Complexity |
|------|---------|--------------|--------------|------------------|
| Standard | Largest | Fastest | Highest | Simple |
| CSR | Medium | Fast | Medium | Simple |
| CSR + Pruned | Smallest | Slower* | Lowest | Complex** |
*Requires network round-trip to embedding server for recomputation
**Needs embedding server and passages file for search

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python3
"""
Sanity check script to verify HNSW index pruning effectiveness.
Tests the difference in file sizes between pruned and non-pruned indices.
"""
import os
import sys
import tempfile
from pathlib import Path
import numpy as np
import json
# Add the project root to the Python path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
# Import backend packages to trigger plugin registration
import leann_backend_hnsw
from leann.api import LeannBuilder
def create_sample_documents(num_docs=1000):
"""Create sample documents for testing"""
documents = []
for i in range(num_docs):
documents.append(f"Sample document {i} with some random text content for testing purposes.")
return documents
def build_index(documents, output_dir, is_recompute=True):
"""Build HNSW index with specified recompute setting"""
index_path = os.path.join(output_dir, "test_index.hnsw")
builder = LeannBuilder(
backend_name="hnsw",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
M=16,
efConstruction=100,
distance_metric="mips",
is_compact=True,
is_recompute=is_recompute
)
for doc in documents:
builder.add_text(doc)
builder.build_index(index_path)
return index_path
def get_file_size(filepath):
"""Get file size in bytes"""
return os.path.getsize(filepath)
def main():
print("🔍 HNSW Pruning Sanity Check")
print("=" * 50)
# Create sample data
print("📊 Creating sample documents...")
documents = create_sample_documents(num_docs=1000)
print(f" Number of documents: {len(documents)}")
with tempfile.TemporaryDirectory() as temp_dir:
print(f"📁 Working in temporary directory: {temp_dir}")
# Build index with pruning (is_recompute=True)
print("\n🔨 Building index with pruning enabled (is_recompute=True)...")
pruned_dir = os.path.join(temp_dir, "pruned")
os.makedirs(pruned_dir, exist_ok=True)
pruned_index_path = build_index(documents, pruned_dir, is_recompute=True)
# Check what files were actually created
print(f" Looking for index files at: {pruned_index_path}")
import glob
files = glob.glob(f"{pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{pruned_index_path}.index"):
pruned_index_file = f"{pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{pruned_dir}/*.index")
if index_files:
pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {pruned_dir}")
pruned_size = get_file_size(pruned_index_file)
print(f" ✅ Pruned index built successfully")
print(f" 📏 Pruned index size: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
# Build index without pruning (is_recompute=False)
print("\n🔨 Building index without pruning (is_recompute=False)...")
non_pruned_dir = os.path.join(temp_dir, "non_pruned")
os.makedirs(non_pruned_dir, exist_ok=True)
non_pruned_index_path = build_index(documents, non_pruned_dir, is_recompute=False)
# Check what files were actually created
print(f" Looking for index files at: {non_pruned_index_path}")
files = glob.glob(f"{non_pruned_index_path}*")
print(f" Found files: {files}")
# Try to find the actual index file
if os.path.exists(f"{non_pruned_index_path}.index"):
non_pruned_index_file = f"{non_pruned_index_path}.index"
else:
# Look for any .index file in the directory
index_files = glob.glob(f"{non_pruned_dir}/*.index")
if index_files:
non_pruned_index_file = index_files[0]
else:
raise FileNotFoundError(f"No .index file found in {non_pruned_dir}")
non_pruned_size = get_file_size(non_pruned_index_file)
print(f" ✅ Non-pruned index built successfully")
print(f" 📏 Non-pruned index size: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
# Compare sizes
print("\n📊 Comparison Results:")
print("=" * 30)
size_diff = non_pruned_size - pruned_size
size_ratio = pruned_size / non_pruned_size if non_pruned_size > 0 else 0
reduction_percent = (1 - size_ratio) * 100
print(f"Non-pruned index: {non_pruned_size:,} bytes ({non_pruned_size/1024:.1f} KB)")
print(f"Pruned index: {pruned_size:,} bytes ({pruned_size/1024:.1f} KB)")
print(f"Size difference: {size_diff:,} bytes ({size_diff/1024:.1f} KB)")
print(f"Size ratio: {size_ratio:.3f}")
print(f"Size reduction: {reduction_percent:.1f}%")
# Verify pruning effectiveness
print("\n🔍 Verification:")
if size_diff > 0:
print(" ✅ Pruning is effective - pruned index is smaller")
if reduction_percent > 10:
print(f" ✅ Significant size reduction: {reduction_percent:.1f}%")
else:
print(f" ⚠️ Small size reduction: {reduction_percent:.1f}%")
else:
print(" ❌ Pruning appears ineffective - no size reduction")
# Check if passages files were created
pruned_passages = f"{pruned_index_path}.passages.json"
non_pruned_passages = f"{non_pruned_index_path}.passages.json"
print(f"\n📄 Passages files:")
print(f" Pruned passages file exists: {os.path.exists(pruned_passages)}")
print(f" Non-pruned passages file exists: {os.path.exists(non_pruned_passages)}")
return True
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)