refactor: passage structure
This commit is contained in:
@@ -74,7 +74,7 @@ def main():
|
||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
||||
print(">>> Basic search results <<<")
|
||||
for i, res in enumerate(results, 1):
|
||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||
|
||||
# --- 3. Recompute search demo ---
|
||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
||||
@@ -107,7 +107,7 @@ def main():
|
||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
||||
print(">>> Recompute search results <<<")
|
||||
for i, res in enumerate(recompute_results, 1):
|
||||
print(f" {i}. ID: {res['id']}, Score: {res['score']:.4f}, Text: '{res['text']}', Metadata: {res['metadata']}")
|
||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
||||
|
||||
# Compare results
|
||||
print(f"\n--- Result comparison ---")
|
||||
@@ -116,8 +116,8 @@ def main():
|
||||
|
||||
print("\nBasic search vs Recompute results:")
|
||||
for i in range(min(len(results), len(recompute_results))):
|
||||
basic_score = results[i]['score']
|
||||
recompute_score = recompute_results[i]['score']
|
||||
basic_score = results[i].score
|
||||
recompute_score = recompute_results[i].score
|
||||
score_diff = abs(basic_score - recompute_score)
|
||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import json
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
@@ -11,6 +11,7 @@ import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import pickle
|
||||
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_backend
|
||||
@@ -19,13 +20,13 @@ from leann.interface import (
|
||||
LeannBackendBuilderInterface,
|
||||
LeannBackendSearcherInterface
|
||||
)
|
||||
from . import _diskannpy as diskannpy
|
||||
|
||||
METRIC_MAP = {
|
||||
"mips": diskannpy.Metric.INNER_PRODUCT,
|
||||
"l2": diskannpy.Metric.L2,
|
||||
"cosine": diskannpy.Metric.COSINE,
|
||||
}
|
||||
def _get_diskann_metrics():
    """Return the mapping from metric name to ``_diskannpy.Metric`` member.

    The compiled ``_diskannpy`` module is imported inside the function, so it
    is only loaded when a caller actually needs the metric table.

    Returns:
        dict: Keys ``"mips"``, ``"l2"`` and ``"cosine"`` mapped to the
        corresponding ``diskannpy.Metric`` values.
    """
    from . import _diskannpy as diskannpy

    metric = diskannpy.Metric
    return dict(
        mips=metric.INNER_PRODUCT,
        l2=metric.L2,
        cosine=metric.COSINE,
    )
|
||||
|
||||
@contextlib.contextmanager
|
||||
def chdir(path):
|
||||
@@ -67,27 +68,8 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
def __init__(self, **kwargs):
|
||||
self.build_params = kwargs
|
||||
|
||||
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
    """Write ``<index_prefix>.passages.json`` for recompute mode (best effort).

    Args:
        index_dir: Directory in which the passages file is created.
        index_prefix: File-name prefix placed before ``.passages.json``.
        **kwargs: Only ``chunks`` is read: a sequence of mappings that each
            provide a ``"text"`` entry.

    The file maps each chunk's position (as a string) to its text. When no
    chunks are supplied an informational message is printed and nothing is
    written. Any exception is printed and suppressed, never propagated.
    """
    try:
        chunk_list = kwargs.get('chunks', [])
        if chunk_list:
            # Keys are positional node ids rendered as strings.
            texts_by_node = {}
            for position, item in enumerate(chunk_list):
                texts_by_node[str(position)] = item["text"]

            target = index_dir / f"{index_prefix}.passages.json"
            with open(target, 'w', encoding='utf-8') as out:
                json.dump(texts_by_node, out, ensure_ascii=False, indent=2)

            print(f"✅ Generated passages file for recompute mode at '{target}' ({len(texts_by_node)} passages)")
        else:
            print("INFO: No chunks data provided, skipping passages file generation for DiskANN.")
    except Exception as err:
        # Failure here is reported but intentionally not re-raised.
        print(f"💥 ERROR: Failed to generate passages file for DiskANN. Exception: {err}")
|
||||
|
||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_prefix = path.stem
|
||||
@@ -102,8 +84,15 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
data_filename = f"{index_prefix}_data.bin"
|
||||
_write_vectors_to_bin(data, index_dir / data_filename)
|
||||
|
||||
# Create label map: integer -> string_id
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, 'wb') as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
build_kwargs = {**self.build_params, **kwargs}
|
||||
metric_str = build_kwargs.get("distance_metric", "mips").lower()
|
||||
METRIC_MAP = _get_diskann_metrics()
|
||||
metric_enum = METRIC_MAP.get(metric_str)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
|
||||
@@ -115,11 +104,11 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
num_threads = build_kwargs.get("num_threads", 8)
|
||||
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
|
||||
codebook_prefix = ""
|
||||
is_recompute = build_kwargs.get("is_recompute", False)
|
||||
|
||||
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy
|
||||
with chdir(index_dir):
|
||||
diskannpy.build_disk_float_index(
|
||||
metric_enum,
|
||||
@@ -134,8 +123,6 @@ class DiskannBuilder(LeannBackendBuilderInterface):
|
||||
codebook_prefix
|
||||
)
|
||||
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
|
||||
if is_recompute:
|
||||
self._generate_passages_file(index_dir, index_prefix, **build_kwargs)
|
||||
except Exception as e:
|
||||
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
|
||||
raise
|
||||
@@ -150,15 +137,6 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
if not self.meta:
|
||||
raise ValueError("DiskannSearcher requires metadata from .meta.json.")
|
||||
|
||||
dimensions = self.meta.get("dimensions")
|
||||
if not dimensions:
|
||||
raise ValueError("Dimensions not found in Leann metadata.")
|
||||
|
||||
self.distance_metric = self.meta.get("distance_metric", "mips").lower()
|
||||
metric_enum = METRIC_MAP.get(self.distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||
|
||||
self.embedding_model = self.meta.get("embedding_model")
|
||||
if not self.embedding_model:
|
||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
|
||||
@@ -167,11 +145,27 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
self.index_dir = path.parent
|
||||
self.index_prefix = path.stem
|
||||
|
||||
# Load the label map
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
with open(label_map_file, 'rb') as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
# Extract parameters for DiskANN
|
||||
distance_metric = kwargs.get("distance_metric", "mips").lower()
|
||||
METRIC_MAP = _get_diskann_metrics()
|
||||
metric_enum = METRIC_MAP.get(distance_metric)
|
||||
if metric_enum is None:
|
||||
raise ValueError(f"Unsupported distance_metric '{distance_metric}'.")
|
||||
|
||||
num_threads = kwargs.get("num_threads", 8)
|
||||
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
|
||||
self.zmq_port = kwargs.get("zmq_port", 6666)
|
||||
|
||||
try:
|
||||
from . import _diskannpy as diskannpy
|
||||
full_index_prefix = str(self.index_dir / self.index_prefix)
|
||||
self._index = diskannpy.StaticDiskFloatIndex(
|
||||
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, self.zmq_port, "", ""
|
||||
@@ -205,22 +199,18 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if not passages_file:
|
||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
||||
if potential_passages_file.exists():
|
||||
passages_file = str(potential_passages_file)
|
||||
print(f"INFO: Automatically found passages file: {passages_file}")
|
||||
|
||||
if not passages_file:
|
||||
raise RuntimeError(
|
||||
f"Recompute mode is enabled, but no passages file was found. "
|
||||
f"A '{self.index_prefix}.passages.json' file should exist in the index directory "
|
||||
f"'{self.index_dir}'. Ensure you build the index with 'recompute=True'."
|
||||
)
|
||||
# Get the passages file path from meta.json
|
||||
if 'passage_sources' in self.meta and self.meta['passage_sources']:
|
||||
passage_source = self.meta['passage_sources'][0]
|
||||
passages_file = passage_source['path']
|
||||
print(f"INFO: Found passages file from metadata: {passages_file}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
|
||||
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
port=self.zmq_port,
|
||||
model_name=self.embedding_model,
|
||||
distance_metric=self.distance_metric,
|
||||
distance_metric=kwargs.get("distance_metric", "mips"),
|
||||
passages_file=passages_file
|
||||
)
|
||||
|
||||
@@ -248,11 +238,23 @@ class DiskannSearcher(LeannBackendSearcherInterface):
|
||||
batch_recompute,
|
||||
global_pruning
|
||||
)
|
||||
return {"labels": labels, "distances": distances}
|
||||
|
||||
# Convert integer labels to string IDs
|
||||
string_labels = []
|
||||
for batch_labels in labels:
|
||||
batch_string_labels = []
|
||||
for int_label in batch_labels:
|
||||
if int_label in self.label_map:
|
||||
batch_string_labels.append(self.label_map[int_label])
|
||||
else:
|
||||
batch_string_labels.append(f"unknown_{int_label}")
|
||||
string_labels.append(batch_string_labels)
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
except Exception as e:
|
||||
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
|
||||
batch_size = query.shape[0]
|
||||
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
|
||||
return {"labels": [[f"error_{i}" for i in range(top_k)] for _ in range(batch_size)],
|
||||
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@@ -41,21 +41,48 @@ class SimplePassageLoader:
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
    Load passages from a JSONL file with label map support.

    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)

    The label map is read from ``leann.labels.map`` in the same directory as
    ``passages_file``. It is a pickled mapping of integer label to string
    passage ID. The returned loader is keyed by ``str(integer_label)``.

    Args:
        passages_file: Path to the ``.jsonl`` passages file.

    Returns:
        A ``SimplePassageLoader`` built from ``{str(int_label): text}``.

    Raises:
        FileNotFoundError: If the passages file or the label map is missing.
        ValueError: If ``passages_file`` does not end with ``.jsonl``.
    """
    from pathlib import Path
    import pickle

    if not os.path.exists(passages_file):
        raise FileNotFoundError(f"Passages file {passages_file} not found.")

    if not passages_file.endswith('.jsonl'):
        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")

    # Load label map (int -> string_id)
    label_map_file = Path(passages_file).parent / "leann.labels.map"
    if not label_map_file.exists():
        raise FileNotFoundError(f"Label map file not found: {label_map_file}")

    # SECURITY: pickle can execute arbitrary code while loading; only point
    # this at label maps written by a trusted index build.
    with open(label_map_file, 'rb') as f:
        label_map = pickle.load(f)
    print(f"Loaded label map with {len(label_map)} entries")

    # Load passages by string ID, skipping blank lines.
    string_id_passages = {}
    with open(passages_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                string_id_passages[passage['id']] = passage['text']

    # Create int ID -> text mapping using label map
    passages_data = {}
    for int_id, string_id in label_map.items():
        if string_id in string_id_passages:
            passages_data[str(int_id)] = string_id_passages[string_id]
        else:
            print(f"WARNING: String ID {string_id} from label map not found in passages")

    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
    return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_embedding_server_thread(
|
||||
zmq_port=5555,
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import json
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List
|
||||
import contextlib
|
||||
import threading
|
||||
import time
|
||||
@@ -11,6 +11,7 @@ import atexit
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import pickle
|
||||
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if self.is_recompute and not self.is_compact:
|
||||
raise ValueError("is_recompute requires is_compact=True for efficiency")
|
||||
|
||||
def build(self, data: np.ndarray, index_path: str, **kwargs):
|
||||
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
|
||||
"""Build HNSW index using FAISS"""
|
||||
from . import faiss
|
||||
|
||||
@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if not data.flags['C_CONTIGUOUS']:
|
||||
data = np.ascontiguousarray(data)
|
||||
|
||||
# Create label map: integer -> string_id
|
||||
label_map = {i: str_id for i, str_id in enumerate(ids)}
|
||||
label_map_file = index_dir / "leann.labels.map"
|
||||
with open(label_map_file, 'wb') as f:
|
||||
pickle.dump(label_map, f)
|
||||
|
||||
metric_str = self.distance_metric.lower()
|
||||
metric_enum = get_metric_map().get(metric_str)
|
||||
if metric_enum is None:
|
||||
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
if self.is_compact:
|
||||
self._convert_to_csr(index_file)
|
||||
|
||||
if self.is_recompute:
|
||||
self._generate_passages_file(index_dir, index_prefix, **kwargs)
|
||||
|
||||
except Exception as e:
|
||||
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
|
||||
raise
|
||||
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
||||
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
|
||||
raise
|
||||
|
||||
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
    """Write the recompute-mode passages file next to the index (best effort).

    Args:
        index_dir: Directory that receives ``<index_prefix>.passages.json``.
        index_prefix: Prefix used to name the passages file.
        **kwargs: Only ``chunks`` is consulted: a sequence of mappings with a
            ``"text"`` entry each.

    The JSON file maps ``str(position)`` to the chunk text. If no chunks are
    given, a message is printed and the method returns without writing. All
    exceptions are printed and swallowed.
    """
    try:
        supplied = kwargs.get('chunks', [])
        if not supplied:
            print("INFO: No chunks data provided, skipping passages file generation")
            return

        # node_id (position, as a string) -> passage text
        node_texts = {str(idx): entry["text"] for idx, entry in enumerate(supplied)}

        out_path = index_dir / f"{index_prefix}.passages.json"
        with open(out_path, 'w', encoding='utf-8') as handle:
            json.dump(node_texts, handle, ensure_ascii=False, indent=2)

        print(f"✅ Generated passages file for recompute mode at '{out_path}' ({len(node_texts)} passages)")
    except Exception as exc:
        # Not re-raised: a missing passages file is not critical for index building.
        print(f"💥 ERROR: Failed to generate passages file. Exception: {exc}")
|
||||
|
||||
class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
|
||||
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
self.index_dir = path.parent
|
||||
self.index_prefix = path.stem
|
||||
|
||||
# Load the label map
|
||||
label_map_file = self.index_dir / "leann.labels.map"
|
||||
if not label_map_file.exists():
|
||||
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
|
||||
|
||||
with open(label_map_file, 'rb') as f:
|
||||
self.label_map = pickle.load(f)
|
||||
|
||||
index_file = self.index_dir / f"{self.index_prefix}.index"
|
||||
if not index_file.exists():
|
||||
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
|
||||
@@ -336,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
passages_file = kwargs.get("passages_file")
|
||||
if not passages_file:
|
||||
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
|
||||
if potential_passages_file.exists():
|
||||
passages_file = str(potential_passages_file)
|
||||
print(f"INFO: Automatically found passages file: {passages_file}")
|
||||
# Get the passages file path from meta.json
|
||||
if 'passage_sources' in self.meta and self.meta['passage_sources']:
|
||||
passage_source = self.meta['passage_sources'][0]
|
||||
passages_file = passage_source['path']
|
||||
print(f"INFO: Found passages file from metadata: {passages_file}")
|
||||
else:
|
||||
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
|
||||
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
|
||||
|
||||
zmq_port = kwargs.get("zmq_port", 5557)
|
||||
server_started = self.embedding_server_manager.start_server(
|
||||
@@ -372,7 +361,18 @@ class HNSWSearcher(LeannBackendSearcherInterface):
|
||||
|
||||
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
|
||||
|
||||
return {"labels": labels, "distances": distances}
|
||||
# Convert integer labels to string IDs
|
||||
string_labels = []
|
||||
for batch_labels in labels:
|
||||
batch_string_labels = []
|
||||
for int_label in batch_labels:
|
||||
if int_label in self.label_map:
|
||||
batch_string_labels.append(self.label_map[int_label])
|
||||
else:
|
||||
batch_string_labels.append(f"unknown_{int_label}")
|
||||
string_labels.append(batch_string_labels)
|
||||
|
||||
return {"labels": string_labels, "distances": distances}
|
||||
|
||||
except Exception as e:
|
||||
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
|
||||
|
||||
@@ -58,21 +58,46 @@ class SimplePassageLoader:
|
||||
|
||||
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
    Load passages from a JSONL file with label map support.

    Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)

    The label map is read from ``leann.labels.map`` in the same directory as
    ``passages_file``. It is a pickled mapping of integer label to string
    passage ID. The returned loader is keyed by ``str(integer_label)``.

    Args:
        passages_file: Path to the ``.jsonl`` passages file.

    Returns:
        A ``SimplePassageLoader`` built from ``{str(int_label): text}``.

    Raises:
        FileNotFoundError: If the passages file or the label map is missing.
        ValueError: If ``passages_file`` does not end with ``.jsonl``.
    """
    from pathlib import Path
    import pickle

    if not os.path.exists(passages_file):
        raise FileNotFoundError(f"Passages file {passages_file} not found.")

    if not passages_file.endswith('.jsonl'):
        raise ValueError(f"Expected .jsonl file format, got: {passages_file}")

    # Load label map (int -> string_id)
    label_map_file = Path(passages_file).parent / "leann.labels.map"
    if not label_map_file.exists():
        raise FileNotFoundError(f"Label map file not found: {label_map_file}")

    # SECURITY: pickle can execute arbitrary code while loading; only point
    # this at label maps written by a trusted index build.
    with open(label_map_file, 'rb') as f:
        label_map = pickle.load(f)
    print(f"Loaded label map with {len(label_map)} entries")

    # Load passages by string ID, skipping blank lines.
    string_id_passages = {}
    with open(passages_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                passage = json.loads(line)
                string_id_passages[passage['id']] = passage['text']

    # Create int ID -> text mapping using label map
    passages_data = {}
    for int_id, string_id in label_map.items():
        if string_id in string_id_passages:
            passages_data[str(int_id)] = string_id_passages[string_id]
        else:
            print(f"WARNING: String ID {string_id} from label map not found in passages")

    print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
    return SimplePassageLoader(passages_data)
|
||||
|
||||
def create_hnsw_embedding_server(
|
||||
passages_file: Optional[str] = None,
|
||||
|
||||
@@ -7,6 +7,8 @@ import json
|
||||
from pathlib import Path
|
||||
import openai
|
||||
from dataclasses import dataclass, field
|
||||
import uuid
|
||||
import pickle
|
||||
|
||||
# --- Helper Functions for Embeddings ---
|
||||
|
||||
@@ -56,11 +58,45 @@ def _get_embedding_dimensions(model_name: str) -> int:
|
||||
@dataclass
class SearchResult:
    """Represents a single search result."""

    # String passage identifier (a single annotation; the stale ``int`` one
    # is gone, matching the string IDs the searcher assigns).
    id: str
    # Value the backend search reported for this hit.
    score: float
    # Passage text.
    text: str
    # Free-form passage metadata; every instance gets its own empty dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PassageManager:
    """Manages passage data and lazy loading from JSONL files.

    The offset index of each source is loaded eagerly in ``__init__``; the
    passage files themselves are only opened when ``get_passage`` is called.
    """

    def __init__(self, passage_sources: List[Dict[str, Any]]):
        """Load the offset index of every ``"jsonl"`` passage source.

        Args:
            passage_sources: Source descriptors. Each must have a ``"type"``
                key. Entries of type ``"jsonl"`` must also provide ``"path"``
                (the passages file) and ``"index_path"`` (a pickled mapping of
                passage ID to file offset). Other types are skipped.

        Raises:
            FileNotFoundError: If a JSONL source's index file does not exist.
        """
        # passage file path -> {passage_id: offset}
        self.offset_maps = {}
        # passage file path -> passage file path (kept for existing readers)
        self.passage_files = {}

        for source in passage_sources:
            if source["type"] != "jsonl":
                continue

            passage_file = source["path"]
            index_file = source["index_path"]

            if not os.path.exists(index_file):
                raise FileNotFoundError(f"Passage index file not found: {index_file}")

            # SECURITY: pickle can execute arbitrary code while loading; only
            # open index files written by a trusted index build.
            with open(index_file, 'rb') as f:
                offset_map = pickle.load(f)

            self.offset_maps[passage_file] = offset_map
            self.passage_files[passage_file] = passage_file

    def get_passage(self, passage_id: str) -> Dict[str, Any]:
        """Lazy load a passage by ID.

        Sources are searched in insertion order; the first offset map that
        contains ``passage_id`` is used.

        Args:
            passage_id: ID to look up in the loaded offset maps.

        Returns:
            The decoded JSON record stored at the indexed offset.

        Raises:
            KeyError: If no offset map contains ``passage_id``.
            json.JSONDecodeError: If the indexed offset does not point at a
                JSON record (for example a stale or corrupt index). This is
                the same exception type as before, with a diagnostic message.
        """
        for passage_file, offset_map in self.offset_maps.items():
            # Offset 0 is valid, so test against None rather than truthiness.
            offset = offset_map.get(passage_id)
            if offset is None:
                continue

            with open(passage_file, 'r', encoding='utf-8') as f:
                f.seek(offset)
                line = f.readline()

            try:
                return json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(
                    f"Offset {offset} for passage ID '{passage_id}' in "
                    f"'{passage_file}' does not point at a JSON record "
                    f"(stale or corrupt index?)",
                    err.doc,
                    err.pos,
                ) from err

        raise KeyError(f"Passage ID not found: {passage_id}")
|
||||
|
||||
# --- Core Classes ---
|
||||
|
||||
class LeannBuilder:
|
||||
@@ -82,7 +118,26 @@ class LeannBuilder:
|
||||
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
|
||||
|
||||
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
    """Queue one passage for indexing and give it a definitive ID.

    Args:
        text: Passage text.
        metadata: Optional metadata mapping. If it has a non-``None``
            ``'id'`` entry, that value is used as the passage ID; otherwise
            a random UUID4 string is generated.

    Raises:
        ValueError: If the ID supplied in ``metadata`` is already used by a
            previously added chunk.
    """
    if metadata is None:
        metadata = {}

    # Check if ID is provided in metadata
    passage_id = metadata.get('id')
    if passage_id is None:
        passage_id = str(uuid.uuid4())
    elif any(chunk['id'] == passage_id for chunk in self.chunks):
        # Only caller-supplied IDs are validated for uniqueness.
        raise ValueError(f"Duplicate passage ID: {passage_id}")

    # Store the definitive ID with the chunk. Each passage is appended
    # exactly once, always with an "id" key.
    self.chunks.append({
        "id": passage_id,
        "text": text,
        "metadata": metadata,
    })
|
||||
|
||||
def build_index(self, index_path: str):
|
||||
if not self.chunks:
|
||||
@@ -92,28 +147,65 @@ class LeannBuilder:
|
||||
self.dimensions = _get_embedding_dimensions(self.embedding_model)
|
||||
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
|
||||
|
||||
path = Path(index_path)
|
||||
index_dir = path.parent
|
||||
index_name = path.name
|
||||
|
||||
# Ensure the directory exists
|
||||
index_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create the passages.jsonl file and offset index
|
||||
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||
|
||||
offset_map = {}
|
||||
|
||||
with open(passages_file, 'w', encoding='utf-8') as f:
|
||||
for chunk in self.chunks:
|
||||
offset = f.tell()
|
||||
passage_data = {
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk["metadata"]
|
||||
}
|
||||
json.dump(passage_data, f, ensure_ascii=False)
|
||||
f.write('\n')
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
# Save the offset map
|
||||
with open(offset_file, 'wb') as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
# Compute embeddings
|
||||
texts_to_embed = [c["text"] for c in self.chunks]
|
||||
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
|
||||
|
||||
|
||||
# Extract string IDs for the backend
|
||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||
|
||||
# Build the vector index
|
||||
current_backend_kwargs = self.backend_kwargs.copy()
|
||||
current_backend_kwargs['dimensions'] = self.dimensions
|
||||
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
|
||||
|
||||
build_kwargs = current_backend_kwargs.copy()
|
||||
build_kwargs['chunks'] = self.chunks
|
||||
builder_instance.build(embeddings, index_path, **build_kwargs)
|
||||
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
|
||||
|
||||
index_dir = Path(index_path).parent
|
||||
leann_meta_path = index_dir / f"{Path(index_path).name}.meta.json"
|
||||
# Create the lightweight meta.json file
|
||||
leann_meta_path = index_dir / f"{index_name}.meta.json"
|
||||
|
||||
meta_data = {
|
||||
"version": "0.1.0",
|
||||
"version": "1.0",
|
||||
"backend_name": self.backend_name,
|
||||
"embedding_model": self.embedding_model,
|
||||
"dimensions": self.dimensions,
|
||||
"backend_kwargs": self.backend_kwargs,
|
||||
"num_chunks": len(self.chunks),
|
||||
"chunks": self.chunks,
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": str(passages_file),
|
||||
"index_path": str(offset_file)
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(leann_meta_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
@@ -136,14 +228,16 @@ class LeannSearcher:
|
||||
backend_name = self.meta_data['backend_name']
|
||||
self.embedding_model = self.meta_data['embedding_model']
|
||||
|
||||
# Initialize the passage manager
|
||||
passage_sources = self.meta_data.get('passage_sources', [])
|
||||
self.passage_manager = PassageManager(passage_sources)
|
||||
|
||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||
if backend_factory is None:
|
||||
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
|
||||
|
||||
final_kwargs = self.meta_data.get("backend_kwargs", {})
|
||||
final_kwargs.update(backend_kwargs)
|
||||
if 'dimensions' not in final_kwargs:
|
||||
final_kwargs['dimensions'] = self.meta_data.get('dimensions')
|
||||
final_kwargs = backend_kwargs.copy()
|
||||
final_kwargs['meta'] = self.meta_data
|
||||
|
||||
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
|
||||
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
|
||||
@@ -155,15 +249,17 @@ class LeannSearcher:
|
||||
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
|
||||
|
||||
enriched_results = []
|
||||
for label, dist in zip(results['labels'][0], results['distances'][0]):
|
||||
if label < len(self.meta_data['chunks']):
|
||||
chunk_info = self.meta_data['chunks'][label]
|
||||
for string_id, dist in zip(results['labels'][0], results['distances'][0]):
|
||||
try:
|
||||
passage_data = self.passage_manager.get_passage(string_id)
|
||||
enriched_results.append(SearchResult(
|
||||
id=label,
|
||||
id=string_id,
|
||||
score=dist,
|
||||
text=chunk_info['text'],
|
||||
metadata=chunk_info.get('metadata', {})
|
||||
text=passage_data['text'],
|
||||
metadata=passage_data.get('metadata', {})
|
||||
))
|
||||
except KeyError:
|
||||
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
|
||||
return enriched_results
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user