diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index d489970..262fa94 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -10,7 +10,6 @@ import asyncio
 import os
 import dotenv
 from leann.api import LeannBuilder, LeannSearcher, LeannChat
-import leann_backend_hnsw  # Import to ensure backend registration
 import shutil
 from pathlib import Path
 
@@ -39,7 +38,7 @@ all_texts = []
 for doc in documents:
     nodes = node_parser.get_nodes_from_documents([doc])
     for node in nodes:
-        all_texts.append(node.text)
+        all_texts.append(node.get_content())
 
 INDEX_DIR = Path("./test_pdf_index")
 INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
@@ -51,7 +50,7 @@ if not INDEX_DIR.exists():
 
     # CSR compact mode with recompute
     builder = LeannBuilder(
-        backend_name="hnsw",
+        backend_name="diskann",
         embedding_model="facebook/contriever",
         graph_degree=32,
         complexity=64,
@@ -74,7 +73,7 @@ async def main():
     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True,embedding_model="facebook/contriever")
+    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
     print(f"Leann: {chat_response}")
 
 if __name__ == "__main__":
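Side note on the `node.text` → `node.get_content()` change above: a minimal sketch of the difference, assuming llama-index's current `llama_index.core` package layout; the splitter settings are illustrative, not the ones used in the example script.

```python
# Sketch only: get_content() goes through the node's content accessor
# (honoring metadata modes on node types where they matter) rather than
# reading the raw .text field; for plain text nodes the two agree.
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

parser = SentenceSplitter(chunk_size=256, chunk_overlap=32)  # illustrative sizes
doc = Document(text="LEANN trades index storage for on-the-fly embedding recomputation.")

for node in parser.get_nodes_from_documents([doc]):
    assert node.get_content() == node.text  # equal for plain TextNodes
    print(node.get_content())
```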
Exception: {e}") raise @@ -141,17 +164,17 @@ class DiskannSearcher(LeannBackendSearcherInterface): print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.") path = Path(index_path) - index_dir = path.parent - index_prefix = path.stem + self.index_dir = path.parent + self.index_prefix = path.stem num_threads = kwargs.get("num_threads", 8) num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0) - zmq_port = kwargs.get("zmq_port", 5555) # Get zmq_port from kwargs + self.zmq_port = kwargs.get("zmq_port", 6666) try: - full_index_prefix = str(index_dir / index_prefix) + full_index_prefix = str(self.index_dir / self.index_prefix) self._index = diskannpy.StaticDiskFloatIndex( - metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, zmq_port, "", "" + metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", "" ) self.num_threads = num_threads self.embedding_server_manager = EmbeddingServerManager( @@ -173,23 +196,36 @@ class DiskannSearcher(LeannBackendSearcherInterface): prune_ratio = kwargs.get("prune_ratio", 0.0) batch_recompute = kwargs.get("batch_recompute", False) global_pruning = kwargs.get("global_pruning", False) + port = kwargs.get("zmq_port", self.zmq_port) if recompute_beighbor_embeddings: print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running") if not self.embedding_model: raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.") - - zmq_port = kwargs.get("zmq_port", 6666) - + + passages_file = kwargs.get("passages_file") + if not passages_file: + potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json" + if potential_passages_file.exists(): + passages_file = str(potential_passages_file) + print(f"INFO: Automatically found passages file: {passages_file}") + + if not passages_file: + raise RuntimeError( + f"Recompute mode is enabled, but no passages file was found. " + f"A '{self.index_prefix}.passages.json' file should exist in the index directory " + f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'." 
diff --git a/packages/leann-core/src/leann/__init__.py b/packages/leann-core/src/leann/__init__.py
index e69de29..2f19395 100644
--- a/packages/leann-core/src/leann/__init__.py
+++ b/packages/leann-core/src/leann/__init__.py
@@ -0,0 +1,9 @@
+# packages/leann-core/src/leann/__init__.py
+from .api import LeannBuilder, LeannChat, LeannSearcher
+from .registry import BACKEND_REGISTRY, autodiscover_backends
+
+# Importing leann is enough to populate BACKEND_REGISTRY: auto-discovery
+# imports every installed leann-backend-* package, whose decorators register.
+autodiscover_backends()
+
+__all__ = ["LeannBuilder", "LeannSearcher", "LeannChat", "BACKEND_REGISTRY"]
diff --git a/packages/leann-core/src/leann/registry.py b/packages/leann-core/src/leann/registry.py
index 1e6bc72..bda797a 100644
--- a/packages/leann-core/src/leann/registry.py
+++ b/packages/leann-core/src/leann/registry.py
@@ -1,6 +1,9 @@
 # packages/leann-core/src/leann/registry.py
 from typing import Dict, TYPE_CHECKING
 
+import importlib
+import importlib.metadata
+
 if TYPE_CHECKING:
     from leann.interface import LeannBackendFactoryInterface
 
@@ -12,4 +15,23 @@ def register_backend(name: str):
         print(f"INFO: Registering backend '{name}'")
         BACKEND_REGISTRY[name] = cls
         return cls
-    return decorator
\ No newline at end of file
+    return decorator
+
+
+def autodiscover_backends():
+    """Automatically discovers and imports all installed 'leann-backend-*' packages."""
+    print("INFO: Starting backend auto-discovery...")
+    discovered_backends = []
+    for dist in importlib.metadata.distributions():
+        dist_name = dist.metadata['name']
+        if dist_name and dist_name.startswith('leann-backend-'):
+            backend_module_name = dist_name.replace('-', '_')
+            discovered_backends.append(backend_module_name)
+
+    for backend_module_name in sorted(discovered_backends):  # sort for deterministic load order
+        try:
+            importlib.import_module(backend_module_name)
+            # The @register_backend decorator prints the registration message.
+        except ImportError as e:
+            print(f"WARN: Could not import backend module '{backend_module_name}': {e}")
+    print("INFO: Backend auto-discovery finished.")
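To make the discovery contract concrete: any distribution named `leann-backend-<name>` that ships a top-level module `leann_backend_<name>` gets imported by `autodiscover_backends()`, and registration happens as an import side effect via the existing decorator. Below is a hypothetical minimal backend module; `ExampleBackend` is a stub for illustration, not a real implementation.

```python
# Hypothetical leann_backend_example/__init__.py, shipped in a distribution
# named 'leann-backend-example'. autodiscover_backends() normalizes that
# distribution name to 'leann_backend_example' and imports it; the decorator
# then runs and puts the class into BACKEND_REGISTRY under "example".
from leann.registry import register_backend


@register_backend("example")
class ExampleBackend:
    """Stub factory; a real backend implements LeannBackendFactoryInterface."""


# After 'import leann' (which calls autodiscover_backends()):
#   from leann import BACKEND_REGISTRY
#   assert BACKEND_REGISTRY["example"] is ExampleBackend
#   LeannBuilder(backend_name="example", ...)  # now resolvable by name
```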