diff --git a/examples/main_cli_example.py b/examples/main_cli_example.py
index 262fa94..a54feaa 100644
--- a/examples/main_cli_example.py
+++ b/examples/main_cli_example.py
@@ -1,6 +1,7 @@
 import faulthandler
 faulthandler.enable()
 
+import argparse
 from llama_index.core import SimpleDirectoryReader, Settings
 from llama_index.core.readers.base import BaseReader
 from llama_index.node_parser.docling import DoclingNodeParser
@@ -50,7 +51,7 @@ if not INDEX_DIR.exists():
 
     # CSR compact mode with recompute
     builder = LeannBuilder(
-        backend_name="diskann",
+        backend_name="hnsw",
         embedding_model="facebook/contriever",
         graph_degree=32,
         complexity=64,
@@ -67,14 +68,27 @@ if not INDEX_DIR.exists():
 else:
     print(f"--- Using existing index at {INDEX_DIR} ---")
 
-async def main():
+async def main(args):
     print(f"\n[PHASE 2] Starting Leann chat session...")
-    chat = LeannChat(index_path=INDEX_PATH)
+
+    llm_config = {
+        "type": args.llm,
+        "model": args.model,
+        "host": args.host
+    }
+
+    chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
     query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explores to achieve the Fairness and Efficiency trade-off?"
     print(f"You: {query}")
-    chat_response = chat.ask(query, top_k=20, recompute_beighbor_embeddings=True)
+    chat_response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True)
     print(f"Leann: {chat_response}")
 
 if __name__ == "__main__":
-    asyncio.run(main())
\ No newline at end of file
+    parser = argparse.ArgumentParser(description="Run Leann Chat with various LLM backends.")
+    parser.add_argument("--llm", type=str, default="hf", choices=["simulated", "ollama", "hf"], help="The LLM backend to use.")
+    parser.add_argument("--model", type=str, default="meta-llama/Llama-3.2-3B-Instruct", help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf).")
+    parser.add_argument("--host", type=str, default="http://localhost:11434", help="The host for the Ollama API.")
+    args = parser.parse_args()
+
+    asyncio.run(main(args))
\ No newline at end of file
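
For reference, a minimal sketch of how the reworked example is driven, both from the shell and programmatically. The `LeannChat` import path and the index path are assumptions for illustration; the `llm_config` keys and the `ask()` kwargs match the code above.

# Hypothetical invocations (flags from the argparse setup above):
#   python examples/main_cli_example.py --llm simulated
#   python examples/main_cli_example.py --llm ollama --model llama3:8b
from leann.api import LeannChat  # import path assumed

chat = LeannChat(
    index_path="./indexes/my_index.leann",  # placeholder path
    llm_config={"type": "ollama", "model": "llama3:8b", "host": "http://localhost:11434"},
)
print(chat.ask("What is LEANN?", top_k=3, recompute_beighbor_embeddings=True))
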
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
index cd982e1..e31f912 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
@@ -5,21 +5,16 @@
 import struct
 from pathlib import Path
 from typing import Dict, Any, List
 import contextlib
-import threading
-import time
-import atexit
-import socket
-import subprocess
-import sys
 import pickle
-from leann.embedding_server_manager import EmbeddingServerManager
+from leann.searcher_base import BaseSearcher
 from leann.registry import register_backend
 from leann.interface import (
     LeannBackendFactoryInterface,
     LeannBackendBuilderInterface,
     LeannBackendSearcherInterface
 )
+
 def _get_diskann_metrics():
     from . import _diskannpy as diskannpy
     return {
@@ -52,211 +47,87 @@ class DiskannBackend(LeannBackendFactoryInterface):
 
     @staticmethod
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
-        path = Path(index_path)
-        meta_path = path.parent / f"{path.name}.meta.json"
-        if not meta_path.exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
-
-        with open(meta_path, 'r') as f:
-            meta = json.load(f)
-
-        # Pass essential metadata to the searcher
-        kwargs['meta'] = meta
         return DiskannSearcher(index_path, **kwargs)
 
 
 class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs
-
     def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
         index_prefix = path.stem
-
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
             data = data.astype(np.float32)
-        if not data.flags['C_CONTIGUOUS']:
-            data = np.ascontiguousarray(data)
-
+
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)
 
-        # Create label map: integer -> string_id
         label_map = {i: str_id for i, str_id in enumerate(ids)}
         label_map_file = index_dir / "leann.labels.map"
         with open(label_map_file, 'wb') as f:
             pickle.dump(label_map, f)
 
         build_kwargs = {**self.build_params, **kwargs}
-        metric_str = build_kwargs.get("distance_metric", "mips").lower()
-        METRIC_MAP = _get_diskann_metrics()
-        metric_enum = METRIC_MAP.get(metric_str)
+        metric_enum = _get_diskann_metrics().get(build_kwargs.get("distance_metric", "mips").lower())
         if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
+            raise ValueError(f"Unsupported distance_metric '{build_kwargs.get('distance_metric')}'.")
 
-        complexity = build_kwargs.get("complexity", 64)
-        graph_degree = build_kwargs.get("graph_degree", 32)
-        final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
-        indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
-        num_threads = build_kwargs.get("num_threads", 8)
-        pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
-        codebook_prefix = ""
-
-        print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
-
         try:
             from . import _diskannpy as diskannpy
             with chdir(index_dir):
                 diskannpy.build_disk_float_index(
-                    metric_enum,
-                    data_filename,
-                    index_prefix,
-                    complexity,
-                    graph_degree,
-                    final_index_ram_limit,
-                    indexing_ram_budget,
-                    num_threads,
-                    pq_disk_bytes,
-                    codebook_prefix
+                    metric_enum, data_filename, index_prefix,
+                    build_kwargs.get("complexity", 64), build_kwargs.get("graph_degree", 32),
+                    build_kwargs.get("search_memory_maximum", 4.0), build_kwargs.get("build_memory_maximum", 8.0),
+                    build_kwargs.get("num_threads", 8), build_kwargs.get("pq_disk_bytes", 0), ""
                 )
-            print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
-        except Exception as e:
-            print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
-            raise
         finally:
             temp_data_file = index_dir / data_filename
             if temp_data_file.exists():
                 os.remove(temp_data_file)
 
-class DiskannSearcher(LeannBackendSearcherInterface):
+class DiskannSearcher(BaseSearcher):
     def __init__(self, index_path: str, **kwargs):
-        self.meta = kwargs.get("meta", {})
-        if not self.meta:
-            raise ValueError("DiskannSearcher requires metadata from .meta.json.")
+        super().__init__(index_path, backend_module_name="leann_backend_diskann.embedding_server", **kwargs)
+        from . import _diskannpy as diskannpy
 
-        self.embedding_model = self.meta.get("embedding_model")
-        if not self.embedding_model:
-            print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
-
-        self.index_path = Path(index_path)
-        self.index_dir = self.index_path.parent
-        self.index_prefix = self.index_path.stem
-
-        # Load the label map
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-
-        with open(label_map_file, 'rb') as f:
-            self.label_map = pickle.load(f)
-
-        # Extract parameters for DiskANN
         distance_metric = kwargs.get("distance_metric", "mips").lower()
-        METRIC_MAP = _get_diskann_metrics()
-        metric_enum = METRIC_MAP.get(distance_metric)
+        metric_enum = _get_diskann_metrics().get(distance_metric)
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
-
-        num_threads = kwargs.get("num_threads", 8)
-        num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
+
+        self.num_threads = kwargs.get("num_threads", 8)
         self.zmq_port = kwargs.get("zmq_port", 6666)
-
-        try:
-            from . import _diskannpy as diskannpy
-            full_index_prefix = str(self.index_dir / self.index_prefix)
-            self._index = diskannpy.StaticDiskFloatIndex(
-                metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
-            )
-            self.num_threads = num_threads
-            self.embedding_server_manager = EmbeddingServerManager(
-                backend_module_name="leann_backend_diskann.embedding_server"
-            )
-            print("✅ DiskANN index loaded successfully.")
-        except Exception as e:
-            print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
-            raise
+
+        full_index_prefix = str(self.index_dir / self.index_path.stem)
+        self._index = diskannpy.StaticDiskFloatIndex(
+            metric_enum, full_index_prefix, self.num_threads,
+            kwargs.get("num_nodes_to_cache", 0), 1, self.zmq_port, "", ""
+        )
 
     def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
-        complexity = kwargs.get("complexity", 256)
-        beam_width = kwargs.get("beam_width", 4)
-
-        USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
-        skip_search_reorder = kwargs.get("skip_search_reorder", False)
-        recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
-        dedup_node_dis = kwargs.get("dedup_node_dis", False)
-        prune_ratio = kwargs.get("prune_ratio", 0.0)
-        batch_recompute = kwargs.get("batch_recompute", False)
-        global_pruning = kwargs.get("global_pruning", False)
-        port = kwargs.get("zmq_port", self.zmq_port)
-
-        if recompute_beighbor_embeddings:
-            print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
-            if not self.embedding_model:
-                raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.")
+        recompute = kwargs.get("recompute_beighbor_embeddings", False)
+        if recompute:
+            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
+            if not meta_file_path.exists():
+                raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
+            zmq_port = kwargs.get("zmq_port", self.zmq_port)
+            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
 
-            passages_file = kwargs.get("passages_file")
-            if not passages_file:
-                # Pass the metadata file instead of a single passage file
-                meta_file_path = self.index_path.parent / f"{self.index_path.name}.meta.json"
-                if meta_file_path.exists():
-                    passages_file = str(meta_file_path)
-                    print(f"INFO: Using metadata file for lazy loading: {passages_file}")
-                else:
-                    raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
-
-            server_started = self.embedding_server_manager.start_server(
-                port=self.zmq_port,
-                model_name=self.embedding_model,
-                distance_metric=kwargs.get("distance_metric", "mips"),
-                passages_file=passages_file
-            )
-
-            if not server_started:
-                raise RuntimeError(f"Failed to start DiskANN embedding server on port {self.zmq_port}")
-
         if query.dtype != np.float32:
             query = query.astype(np.float32)
-
         if query.ndim == 1:
             query = np.expand_dims(query, axis=0)
 
-        try:
-            labels, distances = self._index.batch_search(
-                query,
-                query.shape[0],
-                top_k,
-                complexity,
-                beam_width,
-                self.num_threads,
-                USE_DEFERRED_FETCH,
-                skip_search_reorder,
-                recompute_beighbor_embeddings,
-                dedup_node_dis,
-                prune_ratio,
-                batch_recompute,
-                global_pruning
-            )
-
-            # Convert integer labels to string IDs
-            string_labels = []
-            for batch_labels in labels:
-                batch_string_labels = []
-                for int_label in batch_labels:
-                    if int_label in self.label_map:
-                        batch_string_labels.append(self.label_map[int_label])
-                    else:
-                        batch_string_labels.append(f"unknown_{int_label}")
-                string_labels.append(batch_string_labels)
-
-            return {"labels": string_labels, "distances": distances}
-        except Exception as e:
-            print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
-            batch_size = query.shape[0]
-            return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
-                    "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
-
-    def __del__(self):
-        if hasattr(self, 'embedding_server_manager'):
-            self.embedding_server_manager.stop_server()
+
+        labels, distances = self._index.batch_search(
+            query, query.shape[0], top_k,
+            kwargs.get("complexity", 256), kwargs.get("beam_width", 4), self.num_threads,
+            kwargs.get("USE_DEFERRED_FETCH", False), kwargs.get("skip_search_reorder", False),
+            recompute, kwargs.get("dedup_node_dis", False), kwargs.get("prune_ratio", 0.0),
+            kwargs.get("batch_recompute", False), kwargs.get("global_pruning", False)
+        )
+
+        string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]
+
+        return {"labels": string_labels, "distances": distances}
\ No newline at end of file
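
A minimal sketch of how the slimmed-down DiskANN searcher is now driven. The index path is a placeholder, the 768-dimension query assumes facebook/contriever, and the compiled `_diskannpy` extension must be available; the kwargs mirror the ones read in `search()` above (including the existing `recompute_beighbor_embeddings` spelling).

import numpy as np
from leann_backend_diskann.diskann_backend import DiskannBackend

# Placeholder index path; BaseSearcher expects <name>.meta.json and
# leann.labels.map to sit next to the index.
searcher = DiskannBackend.searcher("./indexes/my_index.leann", distance_metric="mips")

query = np.random.rand(1, 768).astype(np.float32)  # 768 dims assumed (contriever)
results = searcher.search(
    query,
    top_k=3,
    complexity=256,
    beam_width=4,
    recompute_beighbor_embeddings=True,  # spelling matches the existing kwarg
)
print(results["labels"], results["distances"])
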
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
index 126a537..229819a 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
@@ -3,16 +3,9 @@
 import os
 import json
 from pathlib import Path
 from typing import Dict, Any, List
-import contextlib
-import threading
-import time
-import atexit
-import socket
-import subprocess
-import sys
 import pickle
 
-from leann.embedding_server_manager import EmbeddingServerManager
+from leann.searcher_base import BaseSearcher
 from .convert_to_csr import convert_hnsw_graph_to_csr
 from leann.registry import register_backend
@@ -38,306 +31,120 @@ class HNSWBackend(LeannBackendFactoryInterface):
 
     @staticmethod
     def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
-        path = Path(index_path)
-        meta_path = path.parent / f"{path.name}.meta.json"
-        if not meta_path.exists():
-            raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.")
-
-        with open(meta_path, 'r') as f:
-            meta = json.load(f)
-
-        kwargs['meta'] = meta
         return HNSWSearcher(index_path, **kwargs)
 
 
 class HNSWBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs.copy()
-
-        # --- Configuration defaults with standardized names ---
         self.is_compact = self.build_params.setdefault("is_compact", True)
         self.is_recompute = self.build_params.setdefault("is_recompute", True)
-
-        # --- Additional Options ---
-        self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
-        self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
-        self.external_storage_path = self.build_params.get("external_storage_path", None)
-
-        # --- Standard HNSW parameters ---
         self.M = self.build_params.setdefault("M", 32)
         self.efConstruction = self.build_params.setdefault("efConstruction", 200)
         self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
         self.dimensions = self.build_params.get("dimensions")
 
-        if self.is_skip_neighbors and not self.is_compact:
-            raise ValueError("is_skip_neighbors can only be used with is_compact=True")
-
-        if self.is_recompute and not self.is_compact:
-            raise ValueError("is_recompute requires is_compact=True for efficiency")
-
     def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
-        """Build HNSW index using FAISS"""
         from . import faiss
-
         path = Path(index_path)
         index_dir = path.parent
         index_prefix = path.stem
-
         index_dir.mkdir(parents=True, exist_ok=True)
 
         if data.dtype != np.float32:
             data = data.astype(np.float32)
-        if not data.flags['C_CONTIGUOUS']:
-            data = np.ascontiguousarray(data)
-
-        # Create label map: integer -> string_id
+
         label_map = {i: str_id for i, str_id in enumerate(ids)}
         label_map_file = index_dir / "leann.labels.map"
         with open(label_map_file, 'wb') as f:
             pickle.dump(label_map, f)
-
-        metric_str = self.distance_metric.lower()
-        metric_enum = get_metric_map().get(metric_str)
+
+        metric_enum = get_metric_map().get(self.distance_metric.lower())
         if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
+            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
 
-        M = self.M
-        efConstruction = self.efConstruction
-        dim = self.dimensions
-        if not dim:
-            dim = data.shape[1]
+        dim = self.dimensions or data.shape[1]
+        index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
+        index.hnsw.efConstruction = self.efConstruction
 
-        print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
-
-        try:
-            index = faiss.IndexHNSWFlat(dim, M, metric_enum)
-            index.hnsw.efConstruction = efConstruction
-
-            if metric_str == "cosine":
-                faiss.normalize_L2(data)
-
-            index.add(data.shape[0], faiss.swig_ptr(data))
-
-            index_file = index_dir / f"{index_prefix}.index"
-            faiss.write_index(index, str(index_file))
-
-            print(f"✅ HNSW index built successfully at '{index_file}'")
+        if self.distance_metric.lower() == "cosine":
+            faiss.normalize_L2(data)
 
-            if self.is_compact:
-                self._convert_to_csr(index_file)
-
-        except Exception as e:
-            print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
-            raise
+        index.add(data.shape[0], faiss.swig_ptr(data))
+        index_file = index_dir / f"{index_prefix}.index"
+        faiss.write_index(index, str(index_file))
+
+        if self.is_compact:
+            self._convert_to_csr(index_file)
 
     def _convert_to_csr(self, index_file: Path):
-        """Convert built index to CSR format"""
-        try:
-            mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
-            print(f"INFO: Converting HNSW index to {mode_str} format...")
-
-            csr_temp_file = index_file.with_suffix(".csr.tmp")
-
-            success = convert_hnsw_graph_to_csr(
-                str(index_file),
-                str(csr_temp_file),
-                prune_embeddings=self.is_recompute
-            )
-
-            if success:
-                print("✅ CSR conversion successful.")
-                import shutil
-                shutil.move(str(csr_temp_file), str(index_file))
-                print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
-            else:
-                # Clean up and fail fast
-                if csr_temp_file.exists():
-                    os.remove(csr_temp_file)
-                raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
-
-        except Exception as e:
-            print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
-            raise
-
-
-class HNSWSearcher(LeannBackendSearcherInterface):
-    def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
-        """
-        Get storage status from metadata with sensible defaults.
-
-        Returns:
-            A tuple (is_compact, is_pruned).
-        """
-        # Check if metadata has these flags
-        is_compact = self.meta.get('is_compact', True)  # Default to compact (CSR format)
-        is_pruned = self.meta.get('is_pruned', True)  # Default to pruned (embeddings removed)
-
-        print(f"INFO: Storage status from metadata: is_compact={is_compact}, is_pruned={is_pruned}")
-        return is_compact, is_pruned
+        csr_temp_file = index_file.with_suffix(".csr.tmp")
+        success = convert_hnsw_graph_to_csr(
+            str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
+        )
+        if success:
+            import shutil
+            shutil.move(str(csr_temp_file), str(index_file))
+        else:
+            if csr_temp_file.exists():
+                os.remove(csr_temp_file)
+            raise RuntimeError("CSR conversion failed")
 
+
+class HNSWSearcher(BaseSearcher):
     def __init__(self, index_path: str, **kwargs):
+        super().__init__(index_path, backend_module_name="leann_backend_hnsw.hnsw_embedding_server", **kwargs)
         from . import faiss
 
-        self.meta = kwargs.get("meta", {})
-        if not self.meta:
-            raise ValueError("HNSWSearcher requires metadata from .meta.json.")
-        self.dimensions = self.meta.get("dimensions")
-        if not self.dimensions:
-            raise ValueError("Dimensions not found in Leann metadata.")
-
         self.distance_metric = self.meta.get("distance_metric", "mips").lower()
         metric_enum = get_metric_map().get(self.distance_metric)
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
 
-        self.embedding_model = self.meta.get("embedding_model")
-        if not self.embedding_model:
-            print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
+        self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
 
-        # Check for embedding model override (not allowed)
-        if 'embedding_model' in kwargs and kwargs['embedding_model'] != self.embedding_model:
-            raise ValueError(f"Embedding model override not allowed. Index uses '{self.embedding_model}', but got '{kwargs['embedding_model']}'")
-
-        path = Path(index_path)
-        self.index_dir = path.parent
-        self.index_prefix = path.stem
-
-        # Load the label map
-        label_map_file = self.index_dir / "leann.labels.map"
-        if not label_map_file.exists():
-            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
-
-        with open(label_map_file, 'rb') as f:
-            self.label_map = pickle.load(f)
-
-        index_file = self.index_dir / f"{self.index_prefix}.index"
+        index_file = self.index_dir / f"{self.index_path.stem}.index"
         if not index_file.exists():
             raise FileNotFoundError(f"HNSW index file not found at {index_file}")
 
-        # Get storage status from metadata with user overrides
-        self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
-
-        # Allow override of storage parameters via kwargs
-        if 'is_compact' in kwargs:
-            self.is_compact = kwargs['is_compact']
-        if 'is_pruned' in kwargs:
-            self.is_pruned = kwargs['is_pruned']
-
-        # Validate configuration constraints
-        if not self.is_compact and kwargs.get("is_skip_neighbors", False):
-            raise ValueError("is_skip_neighbors can only be used with is_compact=True")
-
-        if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"):
-            raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
-
         hnsw_config = faiss.HNSWIndexConfig()
         hnsw_config.is_compact = self.is_compact
-
-        # Apply additional configuration options with strict validation
-        hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False)
         hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False)
-        hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0)
-        hnsw_config.external_storage_path = kwargs.get("external_storage_path")
-
-        self.zmq_port = kwargs.get("zmq_port", 5557)
-
-        if self.is_pruned and not hnsw_config.is_recompute:
-            raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
-
-        print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
-        print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
-
-        self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
-
-        if self.is_compact:
-            print("✅ Compact CSR format HNSW index loaded successfully.")
-        else:
-            print("✅ Standard HNSW index loaded successfully.")
 
-        self.embedding_server_manager = EmbeddingServerManager(
-            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
-        )
+        if self.is_pruned and not hnsw_config.is_recompute:
+            raise RuntimeError("Index is pruned but recompute is disabled.")
+
+        self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
+
+    def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
+        is_compact = self.meta.get('is_compact', True)
+        is_pruned = self.meta.get('is_pruned', True)
+        return is_compact, is_pruned
 
     def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
-        """Search using HNSW index with optional recompute functionality"""
         from . import faiss
-
-        ef = kwargs.get("ef", 128)
-
+
         if self.is_pruned:
-            print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
-            if not self.embedding_model:
-                raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.")
-
-            passages_file = kwargs.get("passages_file")
-            if not passages_file:
-                # Pass the metadata file instead of a single passage file
-                meta_file_path = self.index_dir / f"{self.index_prefix}.index.meta.json"
-                if meta_file_path.exists():
-                    passages_file = str(meta_file_path)
-                    print(f"INFO: Using metadata file for lazy loading: {passages_file}")
-                else:
-                    raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
-
+            meta_file_path = self.index_dir / f"{self.index_path.name}.meta.json"
+            if not meta_file_path.exists():
+                raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
             zmq_port = kwargs.get("zmq_port", 5557)
-            server_started = self.embedding_server_manager.start_server(
-                port=zmq_port,
-                model_name=self.embedding_model,
-                passages_file=passages_file,
-                distance_metric=self.distance_metric
-            )
-            if not server_started:
-                raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
-
+            self._ensure_server_running(str(meta_file_path), port=zmq_port, **kwargs)
+
         if query.dtype != np.float32:
             query = query.astype(np.float32)
-
         if query.ndim == 1:
             query = np.expand_dims(query, axis=0)
 
         if self.distance_metric == "cosine":
             faiss.normalize_L2(query)
-
-        try:
-            self._index.hnsw.efSearch = ef
-            params = faiss.SearchParametersHNSW()
-            params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
-            params.efSearch = ef
-            params.beam_size = 2  # Match research system beam_size
-
-            batch_size = query.shape[0]
-            distances = np.empty((batch_size, top_k), dtype=np.float32)
-            labels = np.empty((batch_size, top_k), dtype=np.int64)
-
-            self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
-
-            # 🐛 DEBUG: Print raw faiss results before conversion
-            print(f"🔍 DEBUG HNSW Search Results:")
-            print(f"   Query shape: {query.shape}")
-            print(f"   Top_k: {top_k}")
-            print(f"   Raw faiss indices: {labels[0] if len(labels) > 0 else 'No results'}")
-            print(f"   Raw faiss distances: {distances[0] if len(distances) > 0 else 'No results'}")
-
-            # Convert integer labels to string IDs
-            string_labels = []
-            for batch_idx, batch_labels in enumerate(labels):
-                batch_string_labels = []
-                print(f"   Batch {batch_idx} conversion:")
-                for i, int_label in enumerate(batch_labels):
-                    if int_label in self.label_map:
-                        string_id = self.label_map[int_label]
-                        batch_string_labels.append(string_id)
-                        print(f"     faiss[{int_label}] -> passage_id '{string_id}' (distance: {distances[batch_idx][i]:.4f})")
-                    else:
-                        unknown_id = f"unknown_{int_label}"
-                        batch_string_labels.append(unknown_id)
-                        print(f"     faiss[{int_label}] -> {unknown_id} (NOT FOUND in label_map!)")
-                string_labels.append(batch_string_labels)
-
-            return {"labels": string_labels, "distances": distances}
-
-        except Exception as e:
-            print(f"💥 ERROR: HNSW search failed. Exception: {e}")
-            raise
-
-    def __del__(self):
-        if hasattr(self, 'embedding_server_manager'):
-            self.embedding_server_manager.stop_server()
+
+        params = faiss.SearchParametersHNSW()
+        params.zmq_port = kwargs.get("zmq_port", 5557)
+        params.efSearch = kwargs.get("ef", 128)
+        params.beam_size = 2
+
+        batch_size = query.shape[0]
+        distances = np.empty((batch_size, top_k), dtype=np.float32)
+        labels = np.empty((batch_size, top_k), dtype=np.int64)
+
+        self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
+
+        string_labels = [[self.label_map.get(int_label, f"unknown_{int_label}") for int_label in batch_labels] for batch_labels in labels]

+
+        return {"labels": string_labels, "distances": distances}
\ No newline at end of file
diff --git a/packages/leann-core/src/leann/chat.py b/packages/leann-core/src/leann/chat.py
new file mode 100644
index 0000000..5f50dd5
--- /dev/null
+++ b/packages/leann-core/src/leann/chat.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+"""
+This file contains the chat generation logic for the LEANN project,
+supporting different backends like Ollama, Hugging Face Transformers, and a simulation mode.
+"""
+
+from abc import ABC, abstractmethod
+from typing import Dict, Any, Optional
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class LLMInterface(ABC):
+    """Abstract base class for a generic Language Model (LLM) interface."""
+
+    @abstractmethod
+    def ask(self, prompt: str, **kwargs) -> str:
+        """
+        Sends a prompt to the LLM and returns the generated text.
+
+        Args:
+            prompt: The input prompt for the LLM.
+            **kwargs: Additional keyword arguments for the LLM backend.
+
+        Returns:
+            The response string from the LLM.
+        """
+        pass
+
+
+class OllamaChat(LLMInterface):
+    """LLM interface for Ollama models."""
+
+    def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
+        self.model = model
+        self.host = host
+        logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
+        try:
+            import requests
+            # Check if the Ollama server is responsive
+            if host:
+                requests.get(host)
+        except ImportError:
+            raise ImportError("The 'requests' library is required for Ollama. Please install it with 'pip install requests'.")
+        except requests.exceptions.ConnectionError:
+            logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
+            raise ConnectionError(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        import requests
+        import json
+
+        full_url = f"{self.host}/api/generate"
+        payload = {
+            "model": self.model,
+            "prompt": prompt,
+            "stream": False,  # Keep it simple for now
+            "options": kwargs
+        }
+        logger.info(f"Sending request to Ollama: {payload}")
+        try:
+            response = requests.post(full_url, data=json.dumps(payload))
+            response.raise_for_status()
+
+            # The response from Ollama can be a stream of JSON objects, handle this
+            response_parts = response.text.strip().split('\n')
+            full_response = ""
+            for part in response_parts:
+                if part:
+                    json_part = json.loads(part)
+                    full_response += json_part.get("response", "")
+                    if json_part.get("done"):
+                        break
+            return full_response
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Error communicating with Ollama: {e}")
+            return f"Error: Could not get a response from Ollama. Details: {e}"
+
+
+class HFChat(LLMInterface):
+    """LLM interface for local Hugging Face Transformers models."""
+
+    def __init__(self, model_name: str = "deepseek-ai/deepseek-llm-7b-chat"):
+        logger.info(f"Initializing HFChat with model='{model_name}'")
+        try:
+            from transformers import pipeline
+        except ImportError:
+            raise ImportError("The 'transformers' library is required for Hugging Face models. Please install it with 'pip install transformers'.")
+
+        self.pipeline = pipeline("text-generation", model=model_name)
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        # Sensible defaults for text generation
+        params = {
+            "max_length": 500,
+            "num_return_sequences": 1,
+            **kwargs
+        }
+        logger.info(f"Generating text with Hugging Face model with params: {params}")
+        results = self.pipeline(prompt, **params)
+        return results[0]['generated_text']
+
+
+class SimulatedChat(LLMInterface):
+    """A simple simulated chat for testing and development."""
+
+    def ask(self, prompt: str, **kwargs) -> str:
+        logger.info("Simulating LLM call...")
+        print("Prompt sent to LLM (simulation):", prompt[:500] + "...")
+        return "This is a simulated answer from the LLM based on the retrieved context."
+
+
+def get_llm(llm_config: Optional[Dict[str, Any]] = None) -> LLMInterface:
+    """
+    Factory function to get an LLM interface based on configuration.
+
+    Args:
+        llm_config: A dictionary specifying the LLM type and its parameters.
+                    Example: {"type": "ollama", "model": "llama3"}
+                             {"type": "hf", "model": "distilgpt2"}
+                             None (for simulation mode)
+
+    Returns:
+        An instance of an LLMInterface subclass.
+    """
+    if llm_config is None:
+        logger.info("No LLM config provided, defaulting to simulated chat.")
+        return SimulatedChat()
+
+    llm_type = llm_config.get("type", "simulated")
+    model = llm_config.get("model")
+
+    logger.info(f"Attempting to create LLM of type='{llm_type}' with model='{model}'")
+
+    # Fall back to the class defaults when "model"/"host" are absent; passing
+    # None explicitly would otherwise override the keyword defaults above.
+    if llm_type == "ollama":
+        return OllamaChat(model=model or "llama3:8b", host=llm_config.get("host") or "http://localhost:11434")
+    elif llm_type == "hf":
+        return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
+    elif llm_type == "simulated":
+        return SimulatedChat()
+    else:
+        raise ValueError(f"Unknown LLM type: '{llm_type}'")
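
A minimal usage sketch of the new factory. The model name is the default baked into `OllamaChat` above, and extra `ask()` kwargs are forwarded as Ollama "options" per the payload construction in the code:

from leann.chat import get_llm

# No config -> SimulatedChat; needs no external services.
llm = get_llm(None)
print(llm.ask("Summarize LEANN in one sentence."))

# Ollama config assumes a local server at http://localhost:11434.
ollama_llm = get_llm({"type": "ollama", "model": "llama3:8b"})
print(ollama_llm.ask("Hello", temperature=0.2))
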
+ """ + self.index_path = Path(index_path) + self.index_dir = self.index_path.parent + self.meta = kwargs.get("meta", self._load_meta()) + + if not self.meta: + raise ValueError("Searcher requires metadata from .meta.json.") + + self.dimensions = self.meta.get("dimensions") + if not self.dimensions: + raise ValueError("Dimensions not found in Leann metadata.") + + self.embedding_model = self.meta.get("embedding_model") + if not self.embedding_model: + print("WARNING: embedding_model not found in meta.json. Recompute will fail.") + + self.label_map = self._load_label_map() + + self.embedding_server_manager = EmbeddingServerManager( + backend_module_name=backend_module_name + ) + + def _load_meta(self) -> Dict[str, Any]: + """Loads the metadata file associated with the index.""" + # This is the corrected logic for finding the meta file. + meta_path = self.index_dir / f"{self.index_path.name}.meta.json" + if not meta_path.exists(): + raise FileNotFoundError(f"Leann metadata file not found at {meta_path}") + with open(meta_path, 'r', encoding='utf-8') as f: + return json.load(f) + + def _load_label_map(self) -> Dict[int, str]: + """Loads the mapping from integer IDs to string IDs.""" + label_map_file = self.index_dir / "leann.labels.map" + if not label_map_file.exists(): + raise FileNotFoundError(f"Label map file not found: {label_map_file}") + with open(label_map_file, 'rb') as f: + return pickle.load(f) + + def _ensure_server_running(self, passages_source_file: str, port: int, **kwargs) -> None: + """ + Ensures the embedding server is running if recompute is needed. + This is a helper for subclasses. + """ + if not self.embedding_model: + raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") + + server_started = self.embedding_server_manager.start_server( + port=port, + model_name=self.embedding_model, + passages_file=passages_source_file, + distance_metric=kwargs.get("distance_metric"), + ) + if not server_started: + raise RuntimeError(f"Failed to start embedding server on port {kwargs.get('zmq_port')}") + + @abstractmethod + def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: + """ + Search for the top_k nearest neighbors of the query vector. + Must be implemented by subclasses. + """ + pass + + def __del__(self): + """Ensures the embedding server is stopped when the searcher is destroyed.""" + if hasattr(self, 'embedding_server_manager'): + self.embedding_server_manager.stop_server() +