feat: auto-discover backend packages and fix passage generation for DiskANN

Andy Lee
2025-07-06 05:05:49 +00:00
parent 5659174635
commit b4ae57b2c0
4 changed files with 80 additions and 17 deletions

View File

@@ -10,7 +10,6 @@ import asyncio
 import os
 import dotenv
 from leann.api import LeannBuilder, LeannSearcher, LeannChat
-import leann_backend_hnsw # Import to ensure backend registration
 import shutil
 from pathlib import Path
@@ -39,7 +38,7 @@ all_texts = []
 for doc in documents:
     nodes = node_parser.get_nodes_from_documents([doc])
     for node in nodes:
-        all_texts.append(node.text)
+        all_texts.append(node.get_content())

 INDEX_DIR = Path("./test_pdf_index")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
     # CSR compact mode with recompute
     builder = LeannBuilder(
-        backend_name="hnsw",
+        backend_name="diskann",
         embedding_model="facebook/contriever",
         graph_degree=32,
         complexity=64,
@@ -74,7 +73,7 @@ async def main():
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
print(f"You: {query}")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}")
if __name__ == "__main__":
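
After this change the example needs no backend-specific import: importing leann runs autodiscover_backends(), so the backend is selected purely by name. A minimal sketch of the resulting setup, using only names visible in this diff (index building and the chat loop are elided):

from leann.api import LeannBuilder  # importing the leann package triggers backend auto-discovery

builder = LeannBuilder(
    backend_name="diskann",                 # resolved via BACKEND_REGISTRY; no "import leann_backend_*" needed
    embedding_model="facebook/contriever",
    graph_degree=32,
    complexity=64,
)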

View File

@@ -67,6 +67,26 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs

+    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
+        """Generate passages file for recompute mode, mirroring HNSW backend."""
+        try:
+            chunks = kwargs.get('chunks', [])
+            if not chunks:
+                print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
+                return
+
+            passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
+
+            passages_file = index_dir / f"{index_prefix}.passages.json"
+            with open(passages_file, 'w', encoding='utf-8') as f:
+                json.dump(passages_data, f, ensure_ascii=False, indent=2)
+            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
+        except Exception as e:
+            print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
+            pass
+
     def build(self, data: np.ndarray, index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
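
For reference, the mapping _generate_passages_file writes is just node-id strings to chunk text. A tiny sketch with two hypothetical chunks (only the "text" key is read):

chunks = [{"text": "first passage ..."}, {"text": "second passage ..."}]  # hypothetical chunk payloads
passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
# passages_data == {"0": "first passage ...", "1": "second passage ..."}
# json.dump then writes this mapping to "<index_prefix>.passages.json" alongside the index files.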
@@ -95,6 +115,7 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         num_threads = build_kwargs.get("num_threads", 8)
         pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
         codebook_prefix = ""
+        is_recompute = build_kwargs.get("is_recompute", False)

         print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
@@ -113,6 +134,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                 codebook_prefix
             )
             print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
+
+            if is_recompute:
+                self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
         except Exception as e:
             print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
             raise
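
A hypothetical end-to-end builder call, to show how the new kwargs reach _generate_passages_file. The import path, the way build_kwargs merges constructor params with call kwargs, and the toy dimensions are all assumptions; only the is_recompute and chunks names come from this diff:

import numpy as np
from leann_backend_diskann import DiskannBuilder  # import path assumed

data = np.random.rand(4, 768).astype(np.float32)       # toy vectors, illustrative dimension
chunks = [{"text": f"passage {i}"} for i in range(4)]   # passed through to _generate_passages_file

builder = DiskannBuilder(num_threads=8)
builder.build(data, "./toy_index/docs.leann", is_recompute=True, chunks=chunks)
# On success, ./toy_index/docs.passages.json should map "0".."3" to the chunk texts.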
@@ -141,17 +164,17 @@ class DiskannSearcher(LeannBackendSearcherInterface):
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
self.index_dir = path.parent
self.index_prefix = path.stem
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
zmq_port = kwargs.get("zmq_port", 5555) # Get zmq_port from kwargs
self.zmq_port = kwargs.get("zmq_port", 6666)
try:
full_index_prefix = str(index_dir / index_prefix)
full_index_prefix = str(self.index_dir / self.index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, zmq_port, "", ""
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager(
@@ -173,23 +196,36 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         prune_ratio = kwargs.get("prune_ratio", 0.0)
         batch_recompute = kwargs.get("batch_recompute", False)
         global_pruning = kwargs.get("global_pruning", False)
+        port = kwargs.get("zmq_port", self.zmq_port)

         if recompute_beighbor_embeddings:
             print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
             if not self.embedding_model:
                 raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
-            zmq_port = kwargs.get("zmq_port", 6666)
+
+            passages_file = kwargs.get("passages_file")
+            if not passages_file:
+                potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
+                if potential_passages_file.exists():
+                    passages_file = str(potential_passages_file)
+                    print(f"INFO: Automatically found passages file: {passages_file}")
+
+            if not passages_file:
+                raise RuntimeError(
+                    f"Recompute mode is enabled, but no passages file was found. "
+                    f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
+                    f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
+                )

             server_started = self.embedding_server_manager.start_server(
-                port=zmq_port,
+                port=self.zmq_port,
                 model_name=self.embedding_model,
-                distance_metric=self.distance_metric
+                distance_metric=self.distance_metric,
+                passages_file=passages_file
             )
             if not server_started:
-                print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
-                recompute_beighbor_embeddings = False
+                raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")

         if query.dtype != np.float32:
             query = query.astype(np.float32)
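
A sketch of driving this recompute path directly from the searcher. The import path and the positional arguments of search (query vector, neighbor count) are assumptions; the keyword arguments are the ones search reads from kwargs above:

import numpy as np
from leann_backend_diskann import DiskannSearcher  # import path assumed

# Assumes the index was built with is_recompute=True, so pdf_documents.passages.json exists.
searcher = DiskannSearcher("./test_pdf_index/pdf_documents.leann", zmq_port=6666)
query = np.random.rand(768).astype(np.float32)      # illustrative embedding dimension
results = searcher.search(
    query, 20,                             # positional signature assumed
    recompute_beighbor_embeddings=True,    # starts the ZMQ embedding server with the passages file
    # passages_file="./test_pdf_index/pdf_documents.passages.json",  # optional; auto-located when omitted
)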

View File

@@ -0,0 +1,7 @@
+# packages/leann-core/src/leann/__init__.py
+from .api import LeannBuilder, LeannChat, LeannSearcher
+from .registry import BACKEND_REGISTRY, autodiscover_backends
+
+autodiscover_backends()
+
+__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
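
The practical effect of this new __init__.py, as a short hypothetical check: importing leann populates BACKEND_REGISTRY from whatever leann-backend-* distributions are installed.

from leann import BACKEND_REGISTRY   # the import runs autodiscover_backends()
print(sorted(BACKEND_REGISTRY))      # e.g. ['diskann', 'hnsw'] when both backend packages are installed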

View File

@@ -1,6 +1,9 @@
 # packages/leann-core/src/leann/registry.py
 from typing import Dict, TYPE_CHECKING
+import importlib
+import importlib.metadata

 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface
@@ -12,4 +15,22 @@ def register_backend(name: str):
print(f"INFO: Registering backend '{name}'")
BACKEND_REGISTRY[name] = cls
return cls
return decorator
return decorator
def autodiscover_backends():
"""Automatically discovers and imports all 'leann-backend-*' packages."""
print("INFO: Starting backend auto-discovery...")
discovered_backends = []
for dist in importlib.metadata.distributions():
dist_name = dist.metadata['name']
if dist_name.startswith('leann-backend-'):
backend_module_name = dist_name.replace('-', '_')
discovered_backends.append(backend_module_name)
for backend_module_name in sorted(discovered_backends): # sort for deterministic loading
try:
importlib.import_module(backend_module_name)
# Registration message is printed by the decorator
except ImportError as e:
print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
print("INFO: Backend auto-discovery finished.")