Make DiskANN and HNSW work on main example (#2)

* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
This commit is contained in:
Andy Lee
2025-07-06 13:18:12 +08:00
committed by GitHub
parent a38bc0a3fc
commit cf17c85607
9 changed files with 149 additions and 9653 deletions

3
.gitignore vendored
View File

@@ -71,3 +71,6 @@ test_indices*/
test_*.py
!tests/**
packages/leann-backend-diskann/third_party/DiskANN/_deps/
*.meta.json
*.passages.json

View File

@@ -10,7 +10,6 @@ import asyncio
import os
import dotenv
from leann.api import LeannBuilder, LeannSearcher, LeannChat
import leann_backend_hnsw # Import to ensure backend registration
import shutil
from pathlib import Path
@@ -39,7 +38,7 @@ all_texts = []
for doc in documents:
nodes = node_parser.get_nodes_from_documents([doc])
for node in nodes:
all_texts.append(node.text)
all_texts.append(node.get_content())
INDEX_DIR = Path("./test_pdf_index")
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
# CSR compact mode with recompute
builder = LeannBuilder(
backend_name="hnsw",
backend_name="diskann",
embedding_model="facebook/contriever",
graph_degree=32,
complexity=64,
@@ -74,7 +73,7 @@ async def main():
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}")
if __name__ == "__main__":

View File

@@ -3,7 +3,7 @@ import os
import json
import struct
from pathlib import Path
from typing import Dict
from typing import Dict, Any
import contextlib
import threading
import time
@@ -36,7 +36,7 @@ def chdir(path):
finally:
os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
num_vectors, dim = data.shape
with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors))
@@ -67,6 +67,26 @@ class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode, mirroring HNSW backend."""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
return
passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
pass
def build(self, data: np.ndarray, index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
@@ -95,6 +115,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
is_recompute = build_kwargs.get("is_recompute", False)
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
@@ -113,6 +134,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
codebook_prefix
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
if is_recompute:
self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
@@ -141,16 +164,17 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
self.index_dir = path.parent
self.index_prefix = path.stem
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.zmq_port = kwargs.get("zmq_port", 6666)
try:
full_index_prefix = str(index_dir / index_prefix)
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
@@ -161,7 +185,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4)
@@ -172,23 +196,36 @@ class DiskannSearcher(LeannBackendSearcherInterface):
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
port = kwargs.get("zmq_port", self.zmq_port)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_model:
raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
zmq_port = kwargs.get("zmq_port", 6666)
passages_file = kwargs.get("passages_file")
if not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Automatically found passages file: {passages_file}")
if not passages_file:
raise RuntimeError(
f"Recompute mode is enabled, but no passages file was found. "
f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
)
server_started = self.embedding_server_manager.start_server(
port=zmq_port,
port=self.zmq_port,
model_name=self.embedding_model,
distance_metric=self.distance_metric
distance_metric=self.distance_metric,
passages_file=passages_file
)
if not server_started:
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
recompute_beighbor_embeddings = False
raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
if query.dtype != np.float32:
query = query.astype(np.float32)

View File

@@ -7,6 +7,8 @@ import pickle
import argparse
import threading
import time
import json
from typing import Dict, Any, Optional, Union
from transformers import AutoTokenizer, AutoModel
import os
@@ -17,49 +19,49 @@ import numpy as np
RED = "\033[91m"
RESET = "\033[0m"
# 简化的文档存储 - 替代 LazyPassages
class SimpleDocumentStore:
"""简化的文档存储支持任意ID"""
def __init__(self, documents: dict = None):
self.documents = documents or {}
# 默认演示文档
self.default_docs = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
# --- New Passage Loader from HNSW backend ---
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
def __getitem__(self, doc_id):
doc_id = int(doc_id)
# 优先使用指定的文档
if doc_id in self.documents:
return {"text": self.documents[doc_id]}
# 其次使用默认演示文档
if doc_id in self.default_docs:
return {"text": self.default_docs[doc_id]}
# 对于任意其他ID返回通用文档
fallback_docs = [
"This is a general document about technology and programming concepts.",
"This document discusses machine learning and artificial intelligence topics.",
"This content covers data structures, algorithms, and computer science fundamentals.",
"This is a document about software engineering and development practices.",
"This content focuses on databases, data management, and information systems."
]
# 根据ID选择一个fallback文档
fallback_text = fallback_docs[doc_id % len(fallback_docs)]
return {"text": f"[ID:{doc_id}] {fallback_text}"}
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
if str_id in self.passages_data:
return {"text": self.passages_data[str_id]}
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self):
return len(self.documents) + len(self.default_docs)
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSON file
Expected format: {"passage_id": "passage_text", ...}
"""
if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
return SimplePassageLoader()
try:
with open(passages_file, 'r', encoding='utf-8') as f:
passages_data = json.load(f)
print(f"Loaded {len(passages_data)} passages from {passages_file}")
return SimplePassageLoader(passages_data)
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader()
def create_embedding_server_thread(
zmq_port=5555,
model_name="sentence-transformers/all-mpnet-base-v2",
max_batch_size=128,
passages_file: Optional[str] = None,
):
"""
在当前线程中创建并运行 embedding server
@@ -109,15 +111,14 @@ def create_embedding_server_thread(
except Exception as e:
print(f"WARNING: Model optimization failed: {e}")
# 默认演示文档
demo_documents = {
0: "Python is a high-level, interpreted language known for simplicity.",
1: "Machine learning builds systems that learn from data.",
2: "Data structures like arrays, lists, and graphs organize data.",
}
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()
passages = SimpleDocumentStore(demo_documents)
print(f"INFO: Loaded {len(passages)} demo documents")
print(f"INFO: Loaded {len(passages)} passages.")
class DeviceTimer:
"""设备计时器"""
@@ -264,7 +265,13 @@ def create_embedding_server_thread(
for nid in node_ids:
txtinfo = passages[nid]
txt = txtinfo["text"]
texts.append(txt)
if txt:
texts.append(txt)
else:
# 如果文本为空,我们仍然需要一个占位符来进行批处理,
# 但将其ID记录为缺失
texts.append("")
missing_ids.append(nid)
lookup_timer.print_elapsed()
if missing_ids:
@@ -360,18 +367,20 @@ def create_embedding_server(
max_batch_size=128,
lazy_load_passages=False,
model_name="sentence-transformers/all-mpnet-base-v2",
passages_file: Optional[str] = None,
):
"""
原有的 create_embedding_server 函数保持不变
这个是阻塞版本,用于直接运行
"""
create_embedding_server_thread(zmq_port, model_name, max_batch_size)
create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--domain", type=str, default="demo", help="Domain name")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--load-passages", action="store_true", default=True)
parser.add_argument("--load-embeddings", action="store_true", default=False)
parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +403,5 @@ if __name__ == "__main__":
max_batch_size=args.max_batch_size,
lazy_load_passages=args.lazy_load_passages,
model_name=args.model_name,
passages_file=args.passages_file,
)

View File

@@ -303,7 +303,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = kwargs.get("external_storage_path")
hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
self.zmq_port = kwargs.get("zmq_port", 5557)
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -361,13 +362,15 @@ class HNSWSearcher(LeannBackendSearcherInterface):
faiss.normalize_L2(query)
try:
self._index.hnsw.efSearch = ef
params = faiss.SearchParametersHNSW()
params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
return {"labels": labels, "distances": distances}

View File

@@ -0,0 +1,7 @@
# packages/leann-core/src/leann/__init__.py
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import BACKEND_REGISTRY, autodiscover_backends
autodiscover_backends()
__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]

View File

@@ -1,6 +1,9 @@
# packages/leann-core/src/leann/registry.py
from typing import Dict, TYPE_CHECKING
import importlib
import importlib.metadata
if TYPE_CHECKING:
from leann.interface import LeannBackendFactoryInterface
@@ -12,4 +15,22 @@ def register_backend(name: str):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
return decorator
return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")

View File

File diff suppressed because it is too large Load Diff