diff --git a/examples/document_search.py b/examples/document_search.py
index 72d5bdc..b007bb3 100644
--- a/examples/document_search.py
+++ b/examples/document_search.py
@@ -74,7 +74,7 @@ def main():
     print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
     print(">>> Basic search results <<<")
     for i, res in enumerate(results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # --- 3. Recompute search demo ---
     print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
     print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
     print(">>> Recompute search results <<<")
     for i, res in enumerate(recompute_results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # Compare results
     print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():

     print("\nBasic search vs Recompute results:")
     for i in range(min(len(results), len(recompute_results))):
-        basic_score = results[i]['score']
-        recompute_score = recompute_results[i]['score']
+        basic_score = results[i].score
+        recompute_score = recompute_results[i].score
         score_diff = abs(basic_score - recompute_score)
         print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
index 515560a..23ea745 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict, Any
+from typing import Dict, Any, List
 import contextlib
 import threading
 import time
@@ -11,6 +11,7 @@ import atexit
 import socket
 import subprocess
 import sys
+import pickle

 from leann.embedding_server_manager import EmbeddingServerManager
 from leann.registry import register_backend
@@ -19,13 +20,13 @@ from leann.interface import (
     LeannBackendBuilderInterface,
     LeannBackendSearcherInterface
 )
-from . import _diskannpy as diskannpy
-
-METRIC_MAP = {
-    "mips": diskannpy.Metric.INNER_PRODUCT,
-    "l2": diskannpy.Metric.L2,
-    "cosine": diskannpy.Metric.COSINE,
-}
+def _get_diskann_metrics():
+    from . import _diskannpy as diskannpy
+    return {
+        "mips": diskannpy.Metric.INNER_PRODUCT,
+        "l2": diskannpy.Metric.L2,
+        "cosine": diskannpy.Metric.COSINE,
+    }

 @contextlib.contextmanager
 def chdir(path):
@@ -67,27 +68,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs

-    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
-        """Generate passages file for recompute mode, mirroring HNSW backend."""
-        try:
-            chunks = kwargs.get('chunks', [])
-            if not chunks:
-                print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
-                return
-
-            passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
-
-            passages_file = index_dir / f"{index_prefix}.passages.json"
-            with open(passages_file, 'w', encoding='utf-8') as f:
-                json.dump(passages_data, f, ensure_ascii=False, indent=2)
-
-            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
-
-        except Exception as e:
-            print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
-            pass

-    def build(self, data: np.ndarray, index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
         index_prefix = path.stem
@@ -102,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)

+        # Create label map: integer -> string_id
+        label_map = {i: str_id for i, str_id in enumerate(ids)}
+        label_map_file = index_dir / "leann.labels.map"
+        with open(label_map_file, 'wb') as f:
+            pickle.dump(label_map, f)
+
         build_kwargs = {**self.build_params, **kwargs}
         metric_str = build_kwargs.get("distance_metric", "mips").lower()
+        METRIC_MAP = _get_diskann_metrics()
         metric_enum = METRIC_MAP.get(metric_str)
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
@@ -115,11 +104,11 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         num_threads = build_kwargs.get("num_threads", 8)
         pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
         codebook_prefix = ""
-        is_recompute = build_kwargs.get("is_recompute", False)

         print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")

         try:
+            from . import _diskannpy as diskannpy
             with chdir(index_dir):
                 diskannpy.build_disk_float_index(
                     metric_enum,
@@ -134,8 +123,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                     codebook_prefix
                 )
             print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
-            if is_recompute:
-                self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
         except Exception as e:
             print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
             raise
@@ -150,15 +137,6 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         if not self.meta:
             raise ValueError("DiskannSearcher requires metadata from .meta.json.")

-        dimensions = self.meta.get("dimensions")
-        if not dimensions:
-            raise ValueError("Dimensions not found in Leann metadata.")
-
-        self.distance_metric = self.meta.get("distance_metric", "mips").lower()
-        metric_enum = METRIC_MAP.get(self.distance_metric)
-        if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
-
         self.embedding_model = self.meta.get("embedding_model")
         if not self.embedding_model:
             print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
@@ -167,11 +145,27 @@
         self.index_dir = path.parent
         self.index_prefix = path.stem

+        # Load the label map
+        label_map_file = self.index_dir / "leann.labels.map"
+        if not label_map_file.exists():
+            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+        with open(label_map_file, 'rb') as f:
+            self.label_map = pickle.load(f)
+
+        # Extract parameters for DiskANN
+        distance_metric = kwargs.get("distance_metric", "mips").lower()
+        METRIC_MAP = _get_diskann_metrics()
+        metric_enum = METRIC_MAP.get(distance_metric)
+        if metric_enum is None:
+            raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
+
         num_threads = kwargs.get("num_threads", 8)
         num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
         self.zmq_port = kwargs.get("zmq_port", 6666)

         try:
+            from . import _diskannpy as diskannpy
             full_index_prefix = str(self.index_dir / self.index_prefix)
             self._index = diskannpy.StaticDiskFloatIndex(
                 metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
             )
@@ -205,22 +199,18 @@
             passages_file = kwargs.get("passages_file")
             if not passages_file:
-                potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
-                if potential_passages_file.exists():
-                    passages_file = str(potential_passages_file)
-                    print(f"INFO: Automatically found passages file: {passages_file}")
-
-            if not passages_file:
-                raise RuntimeError(
-                    f"Recompute mode is enabled, but no passages file was found. "
-                    f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
-                    f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
-                )
+                # Get the passages file path from meta.json
+                if 'passage_sources' in self.meta and self.meta['passage_sources']:
+                    passage_source = self.meta['passage_sources'][0]
+                    passages_file = passage_source['path']
+                    print(f"INFO: Found passages file from metadata: {passages_file}")
+                else:
+                    raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")

             server_started = self.embedding_server_manager.start_server(
                 port=self.zmq_port,
                 model_name=self.embedding_model,
-                distance_metric=self.distance_metric,
+                distance_metric=kwargs.get("distance_metric", "mips"),
                 passages_file=passages_file
             )
@@ -248,11 +238,23 @@
                 batch_recompute,
                 global_pruning
             )
-            return {"labels": labels, "distances": distances}
+
+            # Convert integer labels to string IDs
+            string_labels = []
+            for batch_labels in labels:
+                batch_string_labels = []
+                for int_label in batch_labels:
+                    if int_label in self.label_map:
+                        batch_string_labels.append(self.label_map[int_label])
+                    else:
+                        batch_string_labels.append(f"unknown_{int_label}")
+                string_labels.append(batch_string_labels)
+
+            return {"labels": string_labels, "distances": distances}
         except Exception as e:
             print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
             batch_size = query.shape[0]
-            return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
+            return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
                     "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}

     def __del__(self):
diff --git a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
index a5e4329..ee2d4b2 100644
--- a/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
+++ b/packages/leann-backend-diskann/leann_backend_diskann/embedding_server.py
@@ -41,21 +41,48 @@ class SimplePassageLoader:

 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
-    Load passages from a JSON file
-    Expected format: {"passage_id": "passage_text", ...}
+    Load passages from a JSONL file with label map support
+    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
     """
-    if not os.path.exists(passages_file):
-        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
-        return SimplePassageLoader()
+    from pathlib import Path
+    import pickle

-    try:
-        with open(passages_file, 'r', encoding='utf-8') as f:
-            passages_data = json.load(f)
-        print(f"Loaded {len(passages_data)} passages from {passages_file}")
-        return SimplePassageLoader(passages_data)
-    except Exception as e:
-        print(f"Error loading passages from {passages_file}: {e}")
-        return SimplePassageLoader()
+    if not os.path.exists(passages_file):
+        raise FileNotFoundError(f"Passages file {passages_file} not found.")
+
+    if not passages_file.endswith('.jsonl'):
+        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
+
+    # Load label map (int -> string_id)
+    passages_dir = Path(passages_file).parent
+    label_map_file = passages_dir / "leann.labels.map"
+
+    label_map = {}
+    if label_map_file.exists():
+        with open(label_map_file, 'rb') as f:
+            label_map = pickle.load(f)
+        print(f"Loaded label map with {len(label_map)} entries")
+    else:
+        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+    # Load passages by string ID
+    string_id_passages = {}
+    with open(passages_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                passage = json.loads(line)
+                string_id_passages[passage['id']] = passage['text']
+
+    # Create int ID -> text mapping using label map
+    passages_data = {}
+    for int_id, string_id in label_map.items():
+        if string_id in string_id_passages:
+            passages_data[str(int_id)] = string_id_passages[string_id]
+        else:
+            print(f"WARNING: String ID {string_id} from label map not found in passages")
+
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    return SimplePassageLoader(passages_data)

 def create_embedding_server_thread(
     zmq_port=5555,
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
index dd678a7..bbf9cec 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict, Any
+from typing import Dict, Any, List
 import contextlib
 import threading
 import time
@@ -11,6 +11,7 @@ import atexit
 import socket
 import subprocess
 import sys
+import pickle

 from leann.embedding_server_manager import EmbeddingServerManager
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if self.is_recompute and not self.is_compact:
             raise ValueError("is_recompute requires is_compact=True for efficiency")

-    def build(self, data: np.ndarray, index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
         """Build HNSW index using FAISS"""
         from . import faiss

@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if not data.flags['C_CONTIGUOUS']:
             data = np.ascontiguousarray(data)

+        # Create label map: integer -> string_id
+        label_map = {i: str_id for i, str_id in enumerate(ids)}
+        label_map_file = index_dir / "leann.labels.map"
+        with open(label_map_file, 'wb') as f:
+            pickle.dump(label_map, f)
+
         metric_str = self.distance_metric.lower()
         metric_enum = get_metric_map().get(metric_str)
         if metric_enum is None:
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
             if self.is_compact:
                 self._convert_to_csr(index_file)

-            if self.is_recompute:
-                self._generate_passages_file(index_dir, index_prefix, **kwargs)
-
         except Exception as e:
             print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
             raise
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
             print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
             raise

-    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
-        """Generate passages file for recompute mode"""
-        try:
-            chunks = kwargs.get('chunks', [])
-            if not chunks:
-                print("INFO: No chunks data provided, skipping passages file generation")
-                return
-
-            # Generate node_id to text mapping
-            passages_data = {}
-            for node_id, chunk in enumerate(chunks):
-                passages_data[str(node_id)] = chunk["text"]
-
-            # Save passages file
-            passages_file = index_dir / f"{index_prefix}.passages.json"
-            with open(passages_file, 'w', encoding='utf-8') as f:
-                json.dump(passages_data, f, ensure_ascii=False, indent=2)
-
-            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
-
-        except Exception as e:
-            print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
-            # Don't raise - this is not critical for index building
-            pass

 class HNSWSearcher(LeannBackendSearcherInterface):
     def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         self.index_dir = path.parent
         self.index_prefix = path.stem

+        # Load the label map
+        label_map_file = self.index_dir / "leann.labels.map"
+        if not label_map_file.exists():
+            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+        with open(label_map_file, 'rb') as f:
+            self.label_map = pickle.load(f)
+
         index_file = self.index_dir / f"{self.index_prefix}.index"
         if not index_file.exists():
             raise FileNotFoundError(f"HNSW index file not found at {index_file}")
@@ -336,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
             passages_file = kwargs.get("passages_file")
             if not passages_file:
-                potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
-                if potential_passages_file.exists():
-                    passages_file = str(potential_passages_file)
-                    print(f"INFO: Automatically found passages file: {passages_file}")
+                # Get the passages file path from meta.json
+                if 'passage_sources' in self.meta and self.meta['passage_sources']:
+                    passage_source = self.meta['passage_sources'][0]
+                    passages_file = passage_source['path']
+                    print(f"INFO: Found passages file from metadata: {passages_file}")
                 else:
-                    raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
+                    raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")

             zmq_port = kwargs.get("zmq_port", 5557)
             server_started = self.embedding_server_manager.start_server(
@@ -372,7 +361,18 @@ class HNSWSearcher(LeannBackendSearcherInterface):
             self._index.search(query.shape[0], faiss.swig_ptr(query), top_k,
                                faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)

-            return {"labels": labels, "distances": distances}
+            # Convert integer labels to string IDs
+            string_labels = []
+            for batch_labels in labels:
+                batch_string_labels = []
+                for int_label in batch_labels:
+                    if int_label in self.label_map:
+                        batch_string_labels.append(self.label_map[int_label])
+                    else:
+                        batch_string_labels.append(f"unknown_{int_label}")
+                string_labels.append(batch_string_labels)
+
+            return {"labels": string_labels, "distances": distances}

         except Exception as e:
             print(f"💥 ERROR: HNSW search failed. Exception: {e}")
diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
index c548cae..c8cd358 100644
--- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
+++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py
@@ -58,21 +58,46 @@ class SimplePassageLoader:

 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
-    Load passages from a JSON file
-    Expected format: {"passage_id": "passage_text", ...}
+    Load passages from a JSONL file with label map support
+    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
     """
     if not os.path.exists(passages_file):
-        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
-        return SimplePassageLoader()
+        raise FileNotFoundError(f"Passages file {passages_file} not found.")

-    try:
-        with open(passages_file, 'r', encoding='utf-8') as f:
-            passages_data = json.load(f)
-        print(f"Loaded {len(passages_data)} passages from {passages_file}")
-        return SimplePassageLoader(passages_data)
-    except Exception as e:
-        print(f"Error loading passages from {passages_file}: {e}")
-        return SimplePassageLoader()
+    if not passages_file.endswith('.jsonl'):
+        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
+
+    # Load label map (int -> string_id)
+    passages_dir = Path(passages_file).parent
+    label_map_file = passages_dir / "leann.labels.map"
+
+    label_map = {}
+    if label_map_file.exists():
+        import pickle
+        with open(label_map_file, 'rb') as f:
+            label_map = pickle.load(f)
+        print(f"Loaded label map with {len(label_map)} entries")
+    else:
+        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+    # Load passages by string ID
+    string_id_passages = {}
+    with open(passages_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                passage = json.loads(line)
+                string_id_passages[passage['id']] = passage['text']
+
+    # Create int ID -> text mapping using label map
+    passages_data = {}
+    for int_id, string_id in label_map.items():
+        if string_id in string_id_passages:
+            passages_data[str(int_id)] = string_id_passages[string_id]
+        else:
+            print(f"WARNING: String ID {string_id} from label map not found in passages")
+
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    return SimplePassageLoader(passages_data)

 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py
index ed6af41..d146bed 100644
--- a/packages/leann-core/src/leann/api.py
+++ b/packages/leann-core/src/leann/api.py
@@ -7,6 +7,8 @@ import json
 from pathlib import Path
 import openai
 from dataclasses import dataclass, field
+import uuid
+import pickle

 # --- Helper Functions for Embeddings ---

@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
 @dataclass
 class SearchResult:
     """Represents a single search result."""
-    id: int
+    id: str
     score: float
     text: str
     metadata: Dict[str, Any] = field(default_factory=dict)

+
+class PassageManager:
+    """Manages passage data and lazy loading from JSONL files."""
+
+    def __init__(self, passage_sources: List[Dict[str, Any]]):
+        self.offset_maps = {}
+        self.passage_files = {}
+
+        for source in passage_sources:
+            if source["type"] == "jsonl":
+                passage_file = source["path"]
+                index_file = source["index_path"]
+
+                if not os.path.exists(index_file):
+                    raise FileNotFoundError(f"Passage index file not found: {index_file}")
+
+                with open(index_file, 'rb') as f:
+                    offset_map = pickle.load(f)
+
+                self.offset_maps[passage_file] = offset_map
+                self.passage_files[passage_file] = passage_file
+
+    def get_passage(self, passage_id: str) -> Dict[str, Any]:
+        """Lazy load a passage by ID."""
+        for passage_file, offset_map in self.offset_maps.items():
+            if passage_id in offset_map:
+                offset = offset_map[passage_id]
+                with open(passage_file, 'r', encoding='utf-8') as f:
+                    f.seek(offset)
+                    line = f.readline()
+                    return json.loads(line)
+
+        raise KeyError(f"Passage ID not found: {passage_id}")
+
 # --- Core Classes ---

 class LeannBuilder:
@@ -82,7 +118,26 @@
         print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")

     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
-        self.chunks.append({"text": text, "metadata": metadata or {}})
+        if metadata is None:
+            metadata = {}
+
+        # Check if ID is provided in metadata
+        passage_id = metadata.get('id')
+        if passage_id is None:
+            passage_id = str(uuid.uuid4())
+        else:
+            # Validate uniqueness
+            existing_ids = {chunk['id'] for chunk in self.chunks}
+            if passage_id in existing_ids:
+                raise ValueError(f"Duplicate passage ID: {passage_id}")
+
+        # Store the definitive ID with the chunk
+        chunk_data = {
+            "id": passage_id,
+            "text": text,
+            "metadata": metadata
+        }
+        self.chunks.append(chunk_data)

     def build_index(self, index_path: str):
         if not self.chunks:
@@ -92,28 +147,65 @@
             self.dimensions = _get_embedding_dimensions(self.embedding_model)
             print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")

+        path = Path(index_path)
+        index_dir = path.parent
+        index_name = path.name
+
+        # Ensure the directory exists
+        index_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create the passages.jsonl file and offset index
+        passages_file = index_dir / f"{index_name}.passages.jsonl"
+        offset_file = index_dir / f"{index_name}.passages.idx"
+
+        offset_map = {}
+
+        with open(passages_file, 'w', encoding='utf-8') as f:
+            for chunk in self.chunks:
+                offset = f.tell()
+                passage_data = {
+                    "id": chunk["id"],
+                    "text": chunk["text"],
+                    "metadata": chunk["metadata"]
+                }
+                json.dump(passage_data, f, ensure_ascii=False)
+                f.write('\n')
+                offset_map[chunk["id"]] = offset
+
+        # Save the offset map
+        with open(offset_file, 'wb') as f:
+            pickle.dump(offset_map, f)
+
+        # Compute embeddings
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
-
+
+        # Extract string IDs for the backend
+        string_ids = [chunk["id"] for chunk in self.chunks]
+
+        # Build the vector index
         current_backend_kwargs = self.backend_kwargs.copy()
         current_backend_kwargs['dimensions'] = self.dimensions
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
-        build_kwargs = current_backend_kwargs.copy()
-        build_kwargs['chunks'] = self.chunks
-        builder_instance.build(embeddings, index_path, **build_kwargs)
+        builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)

-        index_dir = Path(index_path).parent
-        leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
+        # Create the lightweight meta.json file
+        leann_meta_path = index_dir / f"{index_name}.meta.json"
         meta_data = {
-            "version": "0.1.0",
+            "version": "1.0",
             "backend_name": self.backend_name,
             "embedding_model": self.embedding_model,
             "dimensions": self.dimensions,
             "backend_kwargs": self.backend_kwargs,
-            "num_chunks": len(self.chunks),
-            "chunks": self.chunks,
+            "passage_sources": [
+                {
+                    "type": "jsonl",
+                    "path": str(passages_file),
+                    "index_path": str(offset_file)
+                }
+            ]
         }
         with open(leann_meta_path, 'w', encoding='utf-8') as f:
             json.dump(meta_data, f, indent=2)
@@ -136,14 +228,16 @@
         backend_name = self.meta_data['backend_name']
         self.embedding_model = self.meta_data['embedding_model']

+        # Initialize the passage manager
+        passage_sources = self.meta_data.get('passage_sources', [])
+        self.passage_manager = PassageManager(passage_sources)
+
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")

-        final_kwargs = self.meta_data.get("backend_kwargs", {})
-        final_kwargs.update(backend_kwargs)
-        if 'dimensions' not in final_kwargs:
-            final_kwargs['dimensions'] = self.meta_data.get('dimensions')
+        final_kwargs = backend_kwargs.copy()
+        final_kwargs['meta'] = self.meta_data

         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
         print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
@@ -155,15 +249,17 @@
         results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)

         enriched_results = []
-        for label, dist in zip(results['labels'][0], results['distances'][0]):
-            if label < len(self.meta_data['chunks']):
-                chunk_info = self.meta_data['chunks'][label]
+        for string_id, dist in zip(results['labels'][0], results['distances'][0]):
+            try:
+                passage_data = self.passage_manager.get_passage(string_id)
                 enriched_results.append(SearchResult(
-                    id=label,
+                    id=string_id,
                     score=dist,
-                    text=chunk_info['text'],
-                    metadata=chunk_info.get('metadata', {})
+                    text=passage_data['text'],
+                    metadata=passage_data.get('metadata', {})
                 ))
+            except KeyError:
+                print(f"WARNING: Passage ID '{string_id}' not found in passage files")

         return enriched_results
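
Usage sketch for the new string-ID API (not part of the patch): the backend name, paths, and query below are illustrative, and the LeannBuilder/LeannSearcher call signatures are inferred from the code above, so treat this as a sketch rather than the definitive interface.

from leann.api import LeannBuilder, LeannSearcher

# Hypothetical backend name and index path, for illustration only.
builder = LeannBuilder(backend_name="hnsw")
builder.add_text("Paris is the capital of France.", metadata={"id": "doc-1"})
builder.add_text("The Eiffel Tower is in Paris.")  # no id given -> a uuid4 string is assigned
builder.build_index("./indexes/demo")
# build_index now writes demo.passages.jsonl, demo.passages.idx (pickled
# id -> byte-offset map), leann.labels.map (pickled int label -> string id),
# and a lightweight demo.meta.json whose passage_sources entry points the
# searcher at the JSONL file instead of embedding every chunk in meta.json.

searcher = LeannSearcher("./indexes/demo")
for res in searcher.search("Where is the Eiffel Tower?", top_k=2):
    print(res.id, f"{res.score:.4f}", res.text)  # res.id is now a string ID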
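For reference, the offset-index pattern that build_index and PassageManager rely on can be shown standalone: record f.tell() before writing each JSONL line, pickle the resulting id-to-offset map, and later fetch any passage with one seek plus one readline. File names here are arbitrary; this is a self-contained illustration of the technique, not code from the patch.

import json
import pickle

passages = [
    {"id": "doc-1", "text": "Paris is the capital of France.", "metadata": {}},
    {"id": "doc-2", "text": "The Eiffel Tower is in Paris.", "metadata": {}},
]

# Write the JSONL file, remembering the byte offset of each line.
offset_map = {}
with open("demo.passages.jsonl", "w", encoding="utf-8") as f:
    for p in passages:
        offset_map[p["id"]] = f.tell()
        json.dump(p, f, ensure_ascii=False)
        f.write("\n")

# Persist the id -> offset map alongside the passages.
with open("demo.passages.idx", "wb") as f:
    pickle.dump(offset_map, f)

# Lazy lookup: one seek + one readline per passage, never the whole file.
with open("demo.passages.idx", "rb") as f:
    offsets = pickle.load(f)
with open("demo.passages.jsonl", "r", encoding="utf-8") as f:
    f.seek(offsets["doc-2"])
    print(json.loads(f.readline())["text"])  # -> The Eiffel Tower is in Paris.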