refactor: passage structure

Andy Lee
2025-07-06 21:48:38 +00:00
parent 5611f708e9
commit 16705fc44a
6 changed files with 289 additions and 139 deletions

View File

@@ -74,7 +74,7 @@ def main():
     print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
     print(">>> Basic search results <<<")
     for i, res in enumerate(results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # --- 3. Recompute search demo ---
     print(f"\n[PHASE 3] Recompute search using embedding server...")
@@ -107,7 +107,7 @@ def main():
     print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
     print(">>> Recompute search results <<<")
     for i, res in enumerate(recompute_results, 1):
-        print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
+        print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")

     # Compare results
     print(f"\n--- Result comparison ---")
@@ -116,8 +116,8 @@ def main():
     print("\nBasic search vs Recompute results:")
     for i in range(min(len(results), len(recompute_results))):
-        basic_score = results[i]['score']
-        recompute_score = recompute_results[i]['score']
+        basic_score = results[i].score
+        recompute_score = recompute_results[i].score
         score_diff = abs(basic_score - recompute_score)
         print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")

View File

@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict, Any
+from typing import Dict, Any, List
 import contextlib
 import threading
 import time
@@ -11,6 +11,7 @@ import atexit
 import socket
 import subprocess
 import sys
+import pickle

 from leann.embedding_server_manager import EmbeddingServerManager
 from leann.registry import register_backend
@@ -19,13 +20,13 @@ from leann.interface import (
     LeannBackendBuilderInterface,
     LeannBackendSearcherInterface
 )
-from . import _diskannpy as diskannpy
-
-METRIC_MAP = {
-    "mips": diskannpy.Metric.INNER_PRODUCT,
-    "l2": diskannpy.Metric.L2,
-    "cosine": diskannpy.Metric.COSINE,
-}
+def _get_diskann_metrics():
+    from . import _diskannpy as diskannpy
+    return {
+        "mips": diskannpy.Metric.INNER_PRODUCT,
+        "l2": diskannpy.Metric.L2,
+        "cosine": diskannpy.Metric.COSINE,
+    }

 @contextlib.contextmanager
 def chdir(path):
@@ -67,27 +68,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
     def __init__(self, **kwargs):
         self.build_params = kwargs

-    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
-        """Generate passages file for recompute mode, mirroring HNSW backend."""
-        try:
-            chunks = kwargs.get('chunks', [])
-            if not chunks:
-                print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
-                return
-            passages_data = {str(node_id): chunk["text"] for node_id, chunk in enumerate(chunks)}
-            passages_file = index_dir / f"{index_prefix}.passages.json"
-            with open(passages_file, 'w', encoding='utf-8') as f:
-                json.dump(passages_data, f, ensure_ascii=False, indent=2)
-            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
-        except Exception as e:
-            print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {e}")
-            pass
-
-    def build(self, data: np.ndarray, index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
         path = Path(index_path)
         index_dir = path.parent
         index_prefix = path.stem
@@ -102,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         data_filename = f"{index_prefix}_data.bin"
         _write_vectors_to_bin(data, index_dir / data_filename)

+        # Create label map: integer -> string_id
+        label_map = {i: str_id for i, str_id in enumerate(ids)}
+        label_map_file = index_dir / "leann.labels.map"
+        with open(label_map_file, 'wb') as f:
+            pickle.dump(label_map, f)
+
         build_kwargs = {**self.build_params, **kwargs}
         metric_str = build_kwargs.get("distance_metric", "mips").lower()
+        METRIC_MAP = _get_diskann_metrics()
         metric_enum = METRIC_MAP.get(metric_str)
         if metric_enum is None:
             raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
@@ -115,11 +104,11 @@ class DiskannBuilder(LeannBackendBuilderInterface):
         num_threads = build_kwargs.get("num_threads", 8)
         pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
         codebook_prefix = ""
-        is_recompute = build_kwargs.get("is_recompute", False)

         print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
         try:
+            from . import _diskannpy as diskannpy
             with chdir(index_dir):
                 diskannpy.build_disk_float_index(
                     metric_enum,
@@ -134,8 +123,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
                     codebook_prefix
                 )
             print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
-            if is_recompute:
-                self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
         except Exception as e:
             print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
             raise
@@ -150,15 +137,6 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         if not self.meta:
             raise ValueError("DiskannSearcher requires metadata from .meta.json.")

-        dimensions = self.meta.get("dimensions")
-        if not dimensions:
-            raise ValueError("Dimensions not found in Leann metadata.")
-
-        self.distance_metric = self.meta.get("distance_metric", "mips").lower()
-        metric_enum = METRIC_MAP.get(self.distance_metric)
-        if metric_enum is None:
-            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
-
         self.embedding_model = self.meta.get("embedding_model")
         if not self.embedding_model:
             print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
@@ -167,11 +145,27 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         self.index_dir = path.parent
         self.index_prefix = path.stem

+        # Load the label map
+        label_map_file = self.index_dir / "leann.labels.map"
+        if not label_map_file.exists():
+            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+        with open(label_map_file, 'rb') as f:
+            self.label_map = pickle.load(f)
+
+        # Extract parameters for DiskANN
+        distance_metric = kwargs.get("distance_metric", "mips").lower()
+        METRIC_MAP = _get_diskann_metrics()
+        metric_enum = METRIC_MAP.get(distance_metric)
+        if metric_enum is None:
+            raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
+
         num_threads = kwargs.get("num_threads", 8)
         num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
         self.zmq_port = kwargs.get("zmq_port", 6666)

         try:
+            from . import _diskannpy as diskannpy
             full_index_prefix = str(self.index_dir / self.index_prefix)
             self._index = diskannpy.StaticDiskFloatIndex(
                 metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
@@ -205,22 +199,18 @@ class DiskannSearcher(LeannBackendSearcherInterface):
         passages_file = kwargs.get("passages_file")
         if not passages_file:
-            potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
-            if potential_passages_file.exists():
-                passages_file = str(potential_passages_file)
-                print(f"INFO: Automatically found passages file: {passages_file}")
-
-        if not passages_file:
-            raise RuntimeError(
-                f"Recompute mode is enabled, but no passages file was found. "
-                f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
-                f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
-            )
+            # Get the passages file path from meta.json
+            if 'passage_sources' in self.meta and self.meta['passage_sources']:
+                passage_source = self.meta['passage_sources'][0]
+                passages_file = passage_source['path']
+                print(f"INFO: Found passages file from metadata: {passages_file}")
+            else:
+                raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")

         server_started = self.embedding_server_manager.start_server(
             port=self.zmq_port,
             model_name=self.embedding_model,
-            distance_metric=self.distance_metric,
+            distance_metric=kwargs.get("distance_metric", "mips"),
             passages_file=passages_file
         )
@@ -248,11 +238,23 @@ class DiskannSearcher(LeannBackendSearcherInterface):
                 batch_recompute,
                 global_pruning
             )
-            return {"labels": labels, "distances": distances}
+
+            # Convert integer labels to string IDs
+            string_labels = []
+            for batch_labels in labels:
+                batch_string_labels = []
+                for int_label in batch_labels:
+                    if int_label in self.label_map:
+                        batch_string_labels.append(self.label_map[int_label])
+                    else:
+                        batch_string_labels.append(f"unknown_{int_label}")
+                string_labels.append(batch_string_labels)
+
+            return {"labels": string_labels, "distances": distances}
         except Exception as e:
             print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
             batch_size = query.shape[0]
-            return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
+            return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
                     "distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}

     def __del__(self):
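Note: both backends now persist a pickled label map (leann.labels.map) so the integer row labels returned by the native index can be translated back to the caller's string passage IDs. A minimal standalone sketch of that round trip; the file name matches the commit, but the IDs and label arrays are made up for illustration:

import pickle
from pathlib import Path

# Build side: integer row -> string passage ID, in insertion order.
ids = ["doc-1", "doc-2", "doc-3"]
label_map = {i: str_id for i, str_id in enumerate(ids)}
Path("leann.labels.map").write_bytes(pickle.dumps(label_map))

# Search side: translate integer labels returned by the ANN index.
label_map = pickle.loads(Path("leann.labels.map").read_bytes())
int_labels = [[2, 0]]  # one query, top-2 hits (illustrative)
string_labels = [[label_map.get(l, f"unknown_{l}") for l in row] for row in int_labels]
print(string_labels)  # [['doc-3', 'doc-1']]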

View File

@@ -41,21 +41,48 @@ class SimplePassageLoader:

 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
-    Load passages from a JSON file
-    Expected format: {"passage_id": "passage_text", ...}
+    Load passages from a JSONL file with label map support
+    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
     """
-    if not os.path.exists(passages_file):
-        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
-        return SimplePassageLoader()
-    try:
-        with open(passages_file, 'r', encoding='utf-8') as f:
-            passages_data = json.load(f)
-        print(f"Loaded {len(passages_data)} passages from {passages_file}")
-        return SimplePassageLoader(passages_data)
-    except Exception as e:
-        print(f"Error loading passages from {passages_file}: {e}")
-        return SimplePassageLoader()
+    from pathlib import Path
+    import pickle
+
+    if not os.path.exists(passages_file):
+        raise FileNotFoundError(f"Passages file {passages_file} not found.")
+
+    if not passages_file.endswith('.jsonl'):
+        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
+
+    # Load label map (int -> string_id)
+    passages_dir = Path(passages_file).parent
+    label_map_file = passages_dir / "leann.labels.map"
+    label_map = {}
+    if label_map_file.exists():
+        with open(label_map_file, 'rb') as f:
+            label_map = pickle.load(f)
+        print(f"Loaded label map with {len(label_map)} entries")
+    else:
+        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+    # Load passages by string ID
+    string_id_passages = {}
+    with open(passages_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                passage = json.loads(line)
+                string_id_passages[passage['id']] = passage['text']
+
+    # Create int ID -> text mapping using label map
+    passages_data = {}
+    for int_id, string_id in label_map.items():
+        if string_id in string_id_passages:
+            passages_data[str(int_id)] = string_id_passages[string_id]
+        else:
+            print(f"WARNING: String ID {string_id} from label map not found in passages")
+
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    return SimplePassageLoader(passages_data)

 def create_embedding_server_thread(
     zmq_port=5555,
View File

@@ -3,7 +3,7 @@ import os
 import json
 import struct
 from pathlib import Path
-from typing import Dict, Any
+from typing import Dict, Any, List
 import contextlib
 import threading
 import time
@@ -11,6 +11,7 @@ import atexit
 import socket
 import subprocess
 import sys
+import pickle

 from leann.embedding_server_manager import EmbeddingServerManager
 from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if self.is_recompute and not self.is_compact:
             raise ValueError("is_recompute requires is_compact=True for efficiency")

-    def build(self, data: np.ndarray, index_path: str, **kwargs):
+    def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
         """Build HNSW index using FAISS"""
         from . import faiss
@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
         if not data.flags['C_CONTIGUOUS']:
             data = np.ascontiguousarray(data)

+        # Create label map: integer -> string_id
+        label_map = {i: str_id for i, str_id in enumerate(ids)}
+        label_map_file = index_dir / "leann.labels.map"
+        with open(label_map_file, 'wb') as f:
+            pickle.dump(label_map, f)
+
         metric_str = self.distance_metric.lower()
         metric_enum = get_metric_map().get(metric_str)
         if metric_enum is None:
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
             if self.is_compact:
                 self._convert_to_csr(index_file)

-            if self.is_recompute:
-                self._generate_passages_file(index_dir, index_prefix, **kwargs)
-
         except Exception as e:
             print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
             raise
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
             print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
             raise

-    def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
-        """Generate passages file for recompute mode"""
-        try:
-            chunks = kwargs.get('chunks', [])
-            if not chunks:
-                print("INFO: No chunks data provided, skipping passages file generation")
-                return
-
-            # Generate node_id to text mapping
-            passages_data = {}
-            for node_id, chunk in enumerate(chunks):
-                passages_data[str(node_id)] = chunk["text"]
-
-            # Save passages file
-            passages_file = index_dir / f"{index_prefix}.passages.json"
-            with open(passages_file, 'w', encoding='utf-8') as f:
-                json.dump(passages_data, f, ensure_ascii=False, indent=2)
-
-            print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
-        except Exception as e:
-            print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
-            # Don't raise - this is not critical for index building
-            pass
-
 class HNSWSearcher(LeannBackendSearcherInterface):
     def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         self.index_dir = path.parent
         self.index_prefix = path.stem

+        # Load the label map
+        label_map_file = self.index_dir / "leann.labels.map"
+        if not label_map_file.exists():
+            raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+        with open(label_map_file, 'rb') as f:
+            self.label_map = pickle.load(f)
+
         index_file = self.index_dir / f"{self.index_prefix}.index"
         if not index_file.exists():
             raise FileNotFoundError(f"HNSW index file not found at {index_file}")
@@ -336,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
         passages_file = kwargs.get("passages_file")
         if not passages_file:
-            potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
-            if potential_passages_file.exists():
-                passages_file = str(potential_passages_file)
-                print(f"INFO: Automatically found passages file: {passages_file}")
+            # Get the passages file path from meta.json
+            if 'passage_sources' in self.meta and self.meta['passage_sources']:
+                passage_source = self.meta['passage_sources'][0]
+                passages_file = passage_source['path']
+                print(f"INFO: Found passages file from metadata: {passages_file}")
             else:
-                raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
+                raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")

         zmq_port = kwargs.get("zmq_port", 5557)
         server_started = self.embedding_server_manager.start_server(
@@ -372,7 +361,18 @@ class HNSWSearcher(LeannBackendSearcherInterface):
             self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)

-            return {"labels": labels, "distances": distances}
+            # Convert integer labels to string IDs
+            string_labels = []
+            for batch_labels in labels:
+                batch_string_labels = []
+                for int_label in batch_labels:
+                    if int_label in self.label_map:
+                        batch_string_labels.append(self.label_map[int_label])
+                    else:
+                        batch_string_labels.append(f"unknown_{int_label}")
+                string_labels.append(batch_string_labels)
+
+            return {"labels": string_labels, "distances": distances}
         except Exception as e:
             print(f"💥 ERROR: HNSW search failed. Exception: {e}")

View File

@@ -58,21 +58,46 @@ class SimplePassageLoader:

 def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
     """
-    Load passages from a JSON file
-    Expected format: {"passage_id": "passage_text", ...}
+    Load passages from a JSONL file with label map support
+    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
     """
     if not os.path.exists(passages_file):
-        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
-        return SimplePassageLoader()
-    try:
-        with open(passages_file, 'r', encoding='utf-8') as f:
-            passages_data = json.load(f)
-        print(f"Loaded {len(passages_data)} passages from {passages_file}")
-        return SimplePassageLoader(passages_data)
-    except Exception as e:
-        print(f"Error loading passages from {passages_file}: {e}")
-        return SimplePassageLoader()
+        raise FileNotFoundError(f"Passages file {passages_file} not found.")
+
+    if not passages_file.endswith('.jsonl'):
+        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
+
+    # Load label map (int -> string_id)
+    passages_dir = Path(passages_file).parent
+    label_map_file = passages_dir / "leann.labels.map"
+    label_map = {}
+    if label_map_file.exists():
+        import pickle
+        with open(label_map_file, 'rb') as f:
+            label_map = pickle.load(f)
+        print(f"Loaded label map with {len(label_map)} entries")
+    else:
+        raise FileNotFoundError(f"Label map file not found: {label_map_file}")
+
+    # Load passages by string ID
+    string_id_passages = {}
+    with open(passages_file, 'r', encoding='utf-8') as f:
+        for line in f:
+            if line.strip():
+                passage = json.loads(line)
+                string_id_passages[passage['id']] = passage['text']
+
+    # Create int ID -> text mapping using label map
+    passages_data = {}
+    for int_id, string_id in label_map.items():
+        if string_id in string_id_passages:
+            passages_data[str(int_id)] = string_id_passages[string_id]
+        else:
+            print(f"WARNING: String ID {string_id} from label map not found in passages")
+
+    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
+    return SimplePassageLoader(passages_data)

 def create_hnsw_embedding_server(
     passages_file: Optional[str] = None,

View File

@@ -7,6 +7,8 @@ import json
 from pathlib import Path
 import openai
 from dataclasses import dataclass, field
+import uuid
+import pickle

 # --- Helper Functions for Embeddings ---
@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
 @dataclass
 class SearchResult:
     """Represents a single search result."""
-    id: int
+    id: str
     score: float
     text: str
     metadata: Dict[str, Any] = field(default_factory=dict)

+class PassageManager:
+    """Manages passage data and lazy loading from JSONL files."""
+
+    def __init__(self, passage_sources: List[Dict[str, Any]]):
+        self.offset_maps = {}
+        self.passage_files = {}
+        for source in passage_sources:
+            if source["type"] == "jsonl":
+                passage_file = source["path"]
+                index_file = source["index_path"]
+                if not os.path.exists(index_file):
+                    raise FileNotFoundError(f"Passage index file not found: {index_file}")
+                with open(index_file, 'rb') as f:
+                    offset_map = pickle.load(f)
+                self.offset_maps[passage_file] = offset_map
+                self.passage_files[passage_file] = passage_file
+
+    def get_passage(self, passage_id: str) -> Dict[str, Any]:
+        """Lazy load a passage by ID."""
+        for passage_file, offset_map in self.offset_maps.items():
+            if passage_id in offset_map:
+                offset = offset_map[passage_id]
+                with open(passage_file, 'r', encoding='utf-8') as f:
+                    f.seek(offset)
+                    line = f.readline()
+                    return json.loads(line)
+        raise KeyError(f"Passage ID not found: {passage_id}")
+
 # --- Core Classes ---

 class LeannBuilder:
@@ -82,7 +118,26 @@ class LeannBuilder:
         print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")

     def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
-        self.chunks.append({"text": text, "metadata": metadata or {}})
+        if metadata is None:
+            metadata = {}
+
+        # Check if ID is provided in metadata
+        passage_id = metadata.get('id')
+        if passage_id is None:
+            passage_id = str(uuid.uuid4())
+        else:
+            # Validate uniqueness
+            existing_ids = {chunk['id'] for chunk in self.chunks}
+            if passage_id in existing_ids:
+                raise ValueError(f"Duplicate passage ID: {passage_id}")
+
+        # Store the definitive ID with the chunk
+        chunk_data = {
+            "id": passage_id,
+            "text": text,
+            "metadata": metadata
+        }
+        self.chunks.append(chunk_data)

     def build_index(self, index_path: str):
         if not self.chunks:
@@ -92,28 +147,65 @@ class LeannBuilder:
             self.dimensions = _get_embedding_dimensions(self.embedding_model)
             print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")

+        path = Path(index_path)
+        index_dir = path.parent
+        index_name = path.name
+
+        # Ensure the directory exists
+        index_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create the passages.jsonl file and offset index
+        passages_file = index_dir / f"{index_name}.passages.jsonl"
+        offset_file = index_dir / f"{index_name}.passages.idx"
+
+        offset_map = {}
+        with open(passages_file, 'w', encoding='utf-8') as f:
+            for chunk in self.chunks:
+                offset = f.tell()
+                passage_data = {
+                    "id": chunk["id"],
+                    "text": chunk["text"],
+                    "metadata": chunk["metadata"]
+                }
+                json.dump(passage_data, f, ensure_ascii=False)
+                f.write('\n')
+                offset_map[chunk["id"]] = offset
+
+        # Save the offset map
+        with open(offset_file, 'wb') as f:
+            pickle.dump(offset_map, f)
+
+        # Compute embeddings
         texts_to_embed = [c["text"] for c in self.chunks]
         embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)

+        # Extract string IDs for the backend
+        string_ids = [chunk["id"] for chunk in self.chunks]
+
+        # Build the vector index
         current_backend_kwargs = self.backend_kwargs.copy()
         current_backend_kwargs['dimensions'] = self.dimensions
         builder_instance = self.backend_factory.builder(**current_backend_kwargs)
-        build_kwargs = current_backend_kwargs.copy()
-        build_kwargs['chunks'] = self.chunks
-        builder_instance.build(embeddings, index_path, **build_kwargs)
+        builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)

-        index_dir = Path(index_path).parent
-        leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
+        # Create the lightweight meta.json file
+        leann_meta_path = index_dir / f"{index_name}.meta.json"
         meta_data = {
-            "version": "0.1.0",
+            "version": "1.0",
             "backend_name": self.backend_name,
             "embedding_model": self.embedding_model,
             "dimensions": self.dimensions,
             "backend_kwargs": self.backend_kwargs,
-            "num_chunks": len(self.chunks),
-            "chunks": self.chunks,
+            "passage_sources": [
+                {
+                    "type": "jsonl",
+                    "path": str(passages_file),
+                    "index_path": str(offset_file)
+                }
+            ]
         }
         with open(leann_meta_path, 'w', encoding='utf-8') as f:
             json.dump(meta_data, f, indent=2)
@@ -136,14 +228,16 @@ class LeannSearcher:
         backend_name = self.meta_data['backend_name']
         self.embedding_model = self.meta_data['embedding_model']

+        # Initialize the passage manager
+        passage_sources = self.meta_data.get('passage_sources', [])
+        self.passage_manager = PassageManager(passage_sources)
+
         backend_factory = BACKEND_REGISTRY.get(backend_name)
         if backend_factory is None:
             raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")

-        final_kwargs = self.meta_data.get("backend_kwargs", {})
-        final_kwargs.update(backend_kwargs)
-        if 'dimensions' not in final_kwargs:
-            final_kwargs['dimensions'] = self.meta_data.get('dimensions')
+        final_kwargs = backend_kwargs.copy()
+        final_kwargs['meta'] = self.meta_data

         self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
         print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
@@ -155,15 +249,17 @@ class LeannSearcher:
         results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)

         enriched_results = []
-        for label, dist in zip(results['labels'][0], results['distances'][0]):
-            if label < len(self.meta_data['chunks']):
-                chunk_info = self.meta_data['chunks'][label]
-                enriched_results.append(SearchResult(
-                    id=label,
-                    score=dist,
-                    text=chunk_info['text'],
-                    metadata=chunk_info.get('metadata', {})
-                ))
+        for string_id, dist in zip(results['labels'][0], results['distances'][0]):
+            try:
+                passage_data = self.passage_manager.get_passage(string_id)
+                enriched_results.append(SearchResult(
+                    id=string_id,
+                    score=dist,
+                    text=passage_data['text'],
+                    metadata=passage_data.get('metadata', {})
+                ))
+            except KeyError:
+                print(f"WARNING: Passage ID '{string_id}' not found in passage files")
         return enriched_results
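Note: the heart of the refactor is in this last file. build_index now writes every chunk to a .passages.jsonl file and records each line's byte offset in a pickled .passages.idx, and PassageManager later seeks straight to the requested line instead of carrying all chunks inside meta.json. A self-contained sketch of that offset-index round trip, mirroring the commit's approach; file names and passages are illustrative:

import json
import pickle
from pathlib import Path

chunks = [
    {"id": "doc-1", "text": "The quick brown fox.", "metadata": {}},
    {"id": "doc-2", "text": "Jumps over the lazy dog.", "metadata": {}},
]
passages_file = Path("demo.passages.jsonl")
offset_file = Path("demo.passages.idx")

# Write side: remember where each JSONL line starts, keyed by passage ID.
# Offsets are taken at line boundaries, where the text-mode tell/seek round trip is safe.
offset_map = {}
with open(passages_file, "w", encoding="utf-8") as f:
    for chunk in chunks:
        offset_map[chunk["id"]] = f.tell()
        json.dump(chunk, f, ensure_ascii=False)
        f.write("\n")
with open(offset_file, "wb") as f:
    pickle.dump(offset_map, f)

# Read side: seek directly to one passage without loading the whole file.
with open(offset_file, "rb") as f:
    offsets = pickle.load(f)
with open(passages_file, "r", encoding="utf-8") as f:
    f.seek(offsets["doc-2"])
    print(json.loads(f.readline())["text"])  # -> Jumps over the lazy dog.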