Make DiskANN and HNSW work on main example (#2)

* fix: diskann zmq port and passages

* feat: auto discovery of packages and fix passage gen for diskann
Andy Lee authored on 2025-07-06 13:18:12 +08:00 (committed by GitHub)
parent a38bc0a3fc · commit cf17c85607
9 changed files with 149 additions and 9653 deletions

.gitignore
View File

@@ -71,3 +71,6 @@ test_indices*/
 test_*.py
 !tests/**
 packages/leann-backend-diskann/third_party/DiskANN/_deps/
+*.meta.json
+*.passages.json

View File

@@ -10,7 +10,6 @@ import asyncio
 import os
 import dotenv
 from leann.api import LeannBuilder, LeannSearcher, LeannChat
-import leann_backend_hnsw  # Import to ensure backend registration
 import shutil
 from pathlib import Path
@@ -39,7 +38,7 @@ all_texts = []
 for doc in documents:
     nodes = node_parser.get_nodes_from_documents([doc])
     for node in nodes:
-        all_texts.append(node.text)
+        all_texts.append(node.get_content())

 INDEX_DIR = Path("./test_pdf_index")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
     # CSR compact mode with recompute
     builder = LeannBuilder(
-        backend_name="hnsw",
+        backend_name="diskann",
         embedding_model="facebook/contriever",
         graph_degree=32,
         complexity=64,
@@ -74,7 +73,7 @@ async def main():
     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True, embedding_model="facebook/contriever")
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
     print(f"Leann: {chat_response}")

 if __name__ == "__main__":
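Note: with backend auto-discovery the example no longer imports a backend module for its side effect; importing leann is enough, and the builder is switched to the diskann backend. A minimal sketch assembled from the fragments above — the commented add/build calls and the LeannChat constructor signature are assumptions for illustration, not confirmed by this diff:

    from leann.api import LeannBuilder, LeannChat

    builder = LeannBuilder(
        backend_name="diskann",                  # switched from "hnsw" in this commit
        embedding_model="facebook/contriever",
        graph_degree=32,
        complexity=64,
    )
    # Hypothetical: feed chunk texts and build the index at INDEX_PATH
    # (the real example builds from PDF chunks; see the hunks above).
    # builder.add_text(...); builder.build_index(INDEX_PATH)

    chat = LeannChat(INDEX_PATH)                 # constructor signature assumed
    chat.ask("What does LEANN do?", top_k=20, recompute_beighbor_embeddings=True)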

View File

@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict
+from typing import Dict, Any
 import contextlib
 import threading
 import time
@@ -36,7 +36,7 @@ def chdir(path):
     finally:
         os.chdir(original_dir)

-def _write_vectors_to_bin(data: np.ndarray, file_path: str):
+def _write_vectors_to_bin(data: np.ndarray, file_path: Path):
     num_vectors, dim = data.shape
     with open(file_path, 'wb') as f:
         f.write(struct.pack('I', num_vectors))
@@ -67,6 +67,26 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs

+    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
+        """Generate passages file for recompute mode, mirroring HNSW backend."""
+        try:
+            chunks = kwargs.get('chunks', [])
+            if not chunks:
+                print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
+                return
+
+            passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
+            passages_file = index_dir / f"{index_prefix}.passages.json"
+            with open(passages_file, 'w', encoding='utf-8') as f:
+                json.dump(passages_data, f, ensure_ascii=False, indent=2)
+            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
+        except Exception as e:
+            print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
+            pass
+
     def build(self, data: np.ndarray, index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
@@ -95,6 +115,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         num_threads = build_kwargs.get("num_threads", 8)
         pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
         codebook_prefix = ""
+        is_recompute = build_kwargs.get("is_recompute", False)

         print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
@@ -113,6 +134,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                 codebook_prefix
             )
             print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
+
+            if is_recompute:
+                self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
         except Exception as e:
             print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
             raise
@@ -141,16 +164,17 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.") print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path) path = Path(index_path)
index_dir = path.parent self.index_dir = path.parent
index_prefix = path.stem self.index_prefix = path.stem
num_threads = kwargs.get("num_threads", 8) num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0) num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
self.zmq_port = kwargs.get("zmq_port", 6666)
try: try:
full_index_prefix = str(index_dir / index_prefix) full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex( self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", "" metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
) )
self.num_threads = num_threads self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager( self.embedding_server_manager = EmbeddingServerManager(
@@ -161,7 +185,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}") print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]: def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
complexity = kwargs.get("complexity", 256) complexity = kwargs.get("complexity", 256)
beam_width = kwargs.get("beam_width", 4) beam_width = kwargs.get("beam_width", 4)
@@ -172,23 +196,36 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         prune_ratio = kwargs.get("prune_ratio", 0.0)
         batch_recompute = kwargs.get("batch_recompute", False)
         global_pruning = kwargs.get("global_pruning", False)
+        port = kwargs.get("zmq_port", self.zmq_port)

         if recompute_beighbor_embeddings:
             print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
             if not self.embedding_model:
                 raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")

-            zmq_port = kwargs.get("zmq_port", 6666)
+            passages_file = kwargs.get("passages_file")
+            if not passages_file:
+                potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
+                if potential_passages_file.exists():
+                    passages_file = str(potential_passages_file)
+                    print(f"INFO: Automatically found passages file: {passages_file}")
+
+            if not passages_file:
+                raise RuntimeError(
+                    f"Recompute mode is enabled, but no passages file was found. "
+                    f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
+                    f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
+                )

             server_started = self.embedding_server_manager.start_server(
-                port=zmq_port,
+                port=self.zmq_port,
                 model_name=self.embedding_model,
-                distance_metric=self.distance_metric
+                distance_metric=self.distance_metric,
+                passages_file=passages_file
             )
             if not server_started:
-                print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
-                recompute_beighbor_embeddings = False
+                raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")

         if query.dtype != np.float32:
             query = query.astype(np.float32)
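Note: the ZMQ port is now fixed at searcher construction (default 6666) and reused in search(), and the passages file is either passed explicitly or auto-discovered next to the index. A sketch of a recompute-mode search using only the kwargs read above; the constructor call, query dimension, and result keys are assumptions for illustration:

    import numpy as np

    searcher = DiskannSearcher("./test_pdf_index/pdf_documents.leann", zmq_port=6666)
    query = np.random.rand(1, 768).astype(np.float32)   # dim assumed to match the embedding model
    results = searcher.search(
        query, top_k=20,
        complexity=256, beam_width=4,
        recompute_beighbor_embeddings=True,              # triggers the embedding server + passages lookup
        # passages_file="./test_pdf_index/pdf_documents.passages.json",  # optional; auto-discovered if omitted
    )
    # results is a dict, e.g. labels/distances, depending on the backend's return layout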

View File

@@ -7,6 +7,8 @@ import pickle
 import argparse
 import threading
 import time
+import json
+from typing import Dict, Any, Optional, Union

 from transformers import AutoTokenizer, AutoModel
 import os
@@ -17,49 +19,49 @@ import numpy as np
 RED = "\033[91m"
 RESET = "\033[0m"

-# Simplified document store - replaces LazyPassages
-class SimpleDocumentStore:
-    """Simplified document store, supports arbitrary IDs"""
-    def __init__(self, documents: dict = None):
-        self.documents = documents or {}
-        # Default demo documents
-        self.default_docs = {
-            0: "Python is a high-level, interpreted language known for simplicity.",
-            1: "Machine learning builds systems that learn from data.",
-            2: "Data structures like arrays, lists, and graphs organize data.",
-        }
-
-    def __getitem__(self, doc_id):
-        doc_id = int(doc_id)
-        # Prefer the documents that were passed in
-        if doc_id in self.documents:
-            return {"text": self.documents[doc_id]}
-        # Then fall back to the default demo documents
-        if doc_id in self.default_docs:
-            return {"text": self.default_docs[doc_id]}
-        # For any other ID, return a generic document
-        fallback_docs = [
-            "This is a general document about technology and programming concepts.",
-            "This document discusses machine learning and artificial intelligence topics.",
-            "This content covers data structures, algorithms, and computer science fundamentals.",
-            "This is a document about software engineering and development practices.",
-            "This content focuses on databases, data management, and information systems."
-        ]
-        # Pick a fallback document based on the ID
-        fallback_text = fallback_docs[doc_id % len(fallback_docs)]
-        return {"text": f"[ID:{doc_id}] {fallback_text}"}
-
-    def __len__(self):
-        return len(self.documents) + len(self.default_docs)
+# --- New Passage Loader from HNSW backend ---
+class SimplePassageLoader:
+    """
+    Simple passage loader that replaces config.py dependencies
+    """
+    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
+        self.passages_data = passages_data or {}
+
+    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
+        """Get passage by ID"""
+        str_id = str(passage_id)
+        if str_id in self.passages_data:
+            return {"text": self.passages_data[str_id]}
+        else:
+            # Return empty text for missing passages
+            return {"text": ""}
+
+    def __len__(self) -> int:
+        return len(self.passages_data)
+
+def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
+    """
+    Load passages from a JSON file
+    Expected format: {"passage_id": "passage_text", ...}
+    """
+    if not os.path.exists(passages_file):
+        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
+        return SimplePassageLoader()
+    try:
+        with open(passages_file, 'r', encoding='utf-8') as f:
+            passages_data = json.load(f)
+        print(f"Loaded {len(passages_data)} passages from {passages_file}")
+        return SimplePassageLoader(passages_data)
+    except Exception as e:
+        print(f"Error loading passages from {passages_file}: {e}")
+        return SimplePassageLoader()

 def create_embedding_server_thread(
     zmq_port=5555,
     model_name="sentence-transformers/all-mpnet-base-v2",
     max_batch_size=128,
+    passages_file: Optional[str] = None,
 ):
     """
     Create and run the embedding server in the current thread
@@ -109,15 +111,14 @@ def create_embedding_server_thread(
     except Exception as e:
         print(f"WARNING: Model optimization failed: {e}")

-    # Default demo documents
-    demo_documents = {
-        0: "Python is a high-level, interpreted language known for simplicity.",
-        1: "Machine learning builds systems that learn from data.",
-        2: "Data structures like arrays, lists, and graphs organize data.",
-    }
-    passages = SimpleDocumentStore(demo_documents)
-    print(f"INFO: Loaded {len(passages)} demo documents")
+    # Load passages from file if provided
+    if passages_file and os.path.exists(passages_file):
+        passages = load_passages_from_file(passages_file)
+    else:
+        print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
+        passages = SimplePassageLoader()
+    print(f"INFO: Loaded {len(passages)} passages.")

     class DeviceTimer:
         """Device timer"""
@@ -264,7 +265,13 @@ def create_embedding_server_thread(
             for nid in node_ids:
                 txtinfo = passages[nid]
                 txt = txtinfo["text"]
-                texts.append(txt)
+                if txt:
+                    texts.append(txt)
+                else:
+                    # If the text is empty we still need a placeholder for batching,
+                    # but record its ID as missing
+                    texts.append("")
+                    missing_ids.append(nid)
             lookup_timer.print_elapsed()

             if missing_ids:
@@ -360,18 +367,20 @@ def create_embedding_server(
     max_batch_size=128,
     lazy_load_passages=False,
     model_name="sentence-transformers/all-mpnet-base-v2",
+    passages_file: Optional[str] = None,
 ):
     """
     The original create_embedding_server function stays unchanged;
     this is the blocking version, for running directly.
     """
-    create_embedding_server_thread(zmq_port, model_name, max_batch_size)
+    create_embedding_server_thread(zmq_port, model_name, max_batch_size, passages_file)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Embedding service")
     parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
     parser.add_argument("--domain", type=str, default="demo", help="Domain name")
+    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
     parser.add_argument("--load-passages", action="store_true", default=True)
     parser.add_argument("--load-embeddings", action="store_true", default=False)
     parser.add_argument("--use-fp16", action="store_true", default=False)
@@ -394,4 +403,5 @@ if __name__ == "__main__":
         max_batch_size=args.max_batch_size,
         lazy_load_passages=args.lazy_load_passages,
         model_name=args.model_name,
+        passages_file=args.passages_file,
     )
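A quick sketch of the loader contract introduced above, assuming this server module is importable (its file name is not shown in the diff) so load_passages_from_file can be called directly:

    import json

    passages = {"0": "Python is a high-level, interpreted language.",
                "1": "Machine learning builds systems that learn from data."}
    with open("demo.passages.json", "w", encoding="utf-8") as f:
        json.dump(passages, f, ensure_ascii=False)

    loader = load_passages_from_file("demo.passages.json")
    print(len(loader))          # 2
    print(loader[0]["text"])    # int or str IDs both work; missing IDs return {"text": ""}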

View File

@@ -303,7 +303,8 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
         hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
         hnsw_config.external_storage_path = kwargs.get("external_storage_path")
-        hnsw_config.zmq_port = kwargs.get("zmq_port", 5557)
+        self.zmq_port = kwargs.get("zmq_port", 5557)

         if self.is_pruned and not hnsw_config.is_recompute:
             raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
@@ -361,13 +362,15 @@ class HNSWSearcher(LeannBackendSearcherInterface):
             faiss.normalize_L2(query)

         try:
-            self._index.hnsw.efSearch = ef
+            params = faiss.SearchParametersHNSW()
+            params.efSearch = ef
+            params.zmq_port = kwargs.get("zmq_port", self.zmq_port)

             batch_size = query.shape[0]
             distances = np.empty((batch_size, top_k), dtype=np.float32)
             labels = np.empty((batch_size, top_k), dtype=np.int64)
-            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
+            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)

             return {"labels": labels, "distances": distances}

View File

@@ -0,0 +1,7 @@
+# packages/leann-core/src/leann/__init__.py
+from .api import LeannBuilder, LeannChat, LeannSearcher
+from .registry import BACKEND_REGISTRY, autodiscover_backends
+
+autodiscover_backends()
+
+__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
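With this, any import of the leann package (including `from leann.api import ...` as in the updated example) runs autodiscover_backends(), so explicit side-effect imports like `import leann_backend_hnsw` are no longer needed. A minimal check, assuming the backend packages are installed:

    import leann
    print(sorted(leann.BACKEND_REGISTRY.keys()))   # e.g. ['diskann', 'hnsw'] if both backends are installed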

View File

@@ -1,6 +1,9 @@
 # packages/leann-core/src/leann/registry.py
 from typing import Dict, TYPE_CHECKING
+import importlib
+import importlib.metadata

 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface
@@ -12,4 +15,22 @@ def register_backend(name: str):
print(f"INFO: Registering backend '{name}'") print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls BACKEND_REGISTRY[name] = cls
return cls return cls
return decorator return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")

View File

File diff suppressed because it is too large.