feat: reproducible research data, rpj_wiki & dpr

This commit is contained in:
Andy Lee
2025-07-11 02:58:04 +00:00
parent 16705fc44a
commit 8bffb1e5b8
8 changed files with 493 additions and 402 deletions

examples/run_evaluation.py (new file, 157 lines)
View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
This script runs a recall evaluation on a given LEANN index.
It correctly compares results by fetching the text content for both the new search
results and the golden standard results, making the comparison robust to ID changes.
"""
import json
import argparse
import time
from pathlib import Path
import sys
import numpy as np
from typing import List, Dict, Any
import glob
import pickle
# Add project root to path to allow importing from leann
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
from leann.api import LeannSearcher
# --- Configuration ---
NQ_QUERIES_FILE = Path("/opt/dlami/nvme/scaling_out/examples/nq_open.jsonl")
# Ground truth files for different datasets
GROUND_TRUTH_FILES = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/indices/rpj_wiki/facebook/contriever-msmarco/flat_results_nq_k3.json",
"dpr": "/opt/dlami/nvme/scaling_out/indices/dpr/facebook/contriever-msmarco/flat_results_nq_k3.json"
}
# Old passages for different datasets
OLD_PASSAGES_GLOBS = {
"rpj_wiki": "/opt/dlami/nvme/scaling_out/passages/rpj_wiki/8-shards/raw_passages-*-of-8.pkl.jsonl",
"dpr": "/opt/dlami/nvme/scaling_out/passages/dpr/1-shards/raw_passages-*-of-1.pkl.jsonl"
}
# --- Helper Class to Load Original Passages ---
class OldPassageLoader:
"""A simplified version of the old LazyPassages class to fetch golden results by ID."""
def __init__(self, passages_glob: str):
self.jsonl_paths = sorted(glob.glob(passages_glob))
self.offsets = {}
self.fps = [open(p, "r", encoding="utf-8") for p in self.jsonl_paths]
print("Building offset map for original passages...")
for i, shard_path_str in enumerate(self.jsonl_paths):
old_idx_path = Path(shard_path_str.replace(".jsonl", ".idx"))
if not old_idx_path.exists(): continue
with open(old_idx_path, 'rb') as f:
shard_offsets = pickle.load(f)
for pid, offset in shard_offsets.items():
self.offsets[str(pid)] = (i, offset)
print("Offset map for original passages is ready.")
def get_passage_by_id(self, pid: str) -> Dict[str, Any]:
pid = str(pid)
if pid not in self.offsets:
raise ValueError(f"Passage ID {pid} not found in offsets")
file_idx, offset = self.offsets[pid]
fp = self.fps[file_idx]
fp.seek(offset)
return json.loads(fp.readline())
def __del__(self):
for fp in self.fps:
fp.close()
def load_queries(file_path: Path) -> List[str]:
queries = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
queries.append(data['query'])
return queries
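# Note (assumption, inferred from the loader above): each line of nq_open.jsonl
# is a JSON object with a "query" field, e.g. a hypothetical sample line:
#   {"query": "who wrote the declaration of independence"}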
def main():
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
parser.add_argument("index_path", type=str, help="Path to the LEANN index to evaluate.")
parser.add_argument("--num-queries", type=int, default=10, help="Number of queries to evaluate.")
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
parser.add_argument("--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW.")
args = parser.parse_args()
print(f"--- Final, Correct Recall Evaluation (efSearch={args.ef_search}) ---")
# Detect dataset type from index path
index_path_str = str(args.index_path)
if "rpj_wiki" in index_path_str:
dataset_type = "rpj_wiki"
elif "dpr" in index_path_str:
dataset_type = "dpr"
else:
print("WARNING: Unknown dataset type, defaulting to rpj_wiki")
dataset_type = "rpj_wiki"
print(f"INFO: Detected dataset type: {dataset_type}")
try:
searcher = LeannSearcher(args.index_path)
queries = load_queries(NQ_QUERIES_FILE)
golden_results_file = GROUND_TRUTH_FILES[dataset_type]
old_passages_glob = OLD_PASSAGES_GLOBS[dataset_type]
print(f"INFO: Using ground truth file: {golden_results_file}")
print(f"INFO: Using old passages glob: {old_passages_glob}")
with open(golden_results_file, 'r') as f:
golden_results_data = json.load(f)
old_passage_loader = OldPassageLoader(old_passages_glob)
num_eval_queries = min(args.num_queries, len(queries))
queries = queries[:num_eval_queries]
print(f"\nRunning evaluation on {num_eval_queries} queries...")
recall_scores = []
search_times = []
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
search_times.append(time.time() - start_time)
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}
golden_ids = golden_results_data["indices"][i][:args.top_k]
golden_texts = {old_passage_loader.get_passage_by_id(str(gid))['text'] for gid in golden_ids}
overlap = len(new_texts & golden_texts)
recall = overlap / len(golden_texts) if golden_texts else 0
recall_scores.append(recall)
print("\n--- EVALUATION RESULTS ---")
print(f"Query: {queries[i]}")
print(f"New Results: {new_texts}")
print(f"Golden Results: {golden_texts}")
print(f"Overlap: {overlap}")
print(f"Recall: {recall}")
print(f"Search Time: {search_times[-1]:.4f}s")
print(f"--------------------------------")
avg_recall = np.mean(recall_scores) if recall_scores else 0
avg_time = np.mean(search_times) if search_times else 0
print(f"\n🎉 --- Evaluation Complete ---")
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
print(f"Avg. Search Time: {avg_time:.4f}s")
except Exception as e:
print(f"\n❌ An error occurred during evaluation: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
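The heart of this script is the text-overlap recall@k computed in the loop above. A minimal standalone sketch of the same metric, assuming both result sets are plain passage strings:

from typing import Iterable

def recall_at_k(new_texts: Iterable[str], golden_texts: Iterable[str]) -> float:
    """Fraction of golden passages whose exact text appears in the new results."""
    new_set, golden_set = set(new_texts), set(golden_texts)
    if not golden_set:
        return 0.0
    return len(new_set & golden_set) / len(golden_set)

# Example: 2 of 3 golden passages recovered -> recall@3 = 2/3
assert abs(recall_at_k({"a", "b", "d"}, {"a", "b", "c"}) - 2 / 3) < 1e-9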

View File

@@ -141,9 +141,9 @@ class DiskannSearcher(LeannBackendSearcherInterface):
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
self.index_path = Path(index_path)
self.index_dir = self.index_path.parent
self.index_prefix = self.index_path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
@@ -199,13 +199,13 @@ class DiskannSearcher(LeannBackendSearcherInterface):
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
# Pass the metadata file instead of a single passage file
meta_file_path = self.index_path.parent / f"{self.index_path.name}.meta.json"
if meta_file_path.exists():
passages_file = str(meta_file_path)
print(f"INFO: Using metadata file for lazy loading: {passages_file}")
else:
raise RuntimeError(f"FATAL: Recompute mode enabled but no passage_sources found in metadata.")
raise RuntimeError(f"FATAL: Recompute mode enabled but metadata file not found: {meta_file_path}")
server_started = self.embedding_server_manager.start_server(
port=self.zmq_port,

View File

@@ -39,6 +39,71 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages via the metadata file, using PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
from pathlib import Path
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
raise RuntimeError(f"FATAL: Empty text for ID {int_id} -> {string_id}")
else:
raise RuntimeError(f"FATAL: ID {int_id} not found in label_map")
except Exception as e:
raise RuntimeError(f"FATAL: Exception getting passage {passage_id}: {e}")
def __len__(self) -> int:
return len(self.label_map)
return LazyPassageLoader(passage_manager, label_map)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
@@ -140,7 +205,21 @@ def create_embedding_server_thread(
# Load passages from file if provided
if passages_file and os.path.exists(passages_file):
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
from pathlib import Path
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = load_passages_from_file(passages_file)
else:
print("WARNING: No passages file provided or file not found. Using an empty passage loader.")
passages = SimplePassageLoader()

View File

@@ -1,7 +1,6 @@
import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict, Any, List
import contextlib
@@ -161,83 +160,19 @@ class HNSWBuilder(LeannBackendBuilderInterface):
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
def _get_index_storage_status_from_meta(self) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Get storage status from metadata with sensible defaults.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
# Check if metadata has these flags
is_compact = self.meta.get('is_compact', True) # Default to compact (CSR format)
is_pruned = self.meta.get('is_pruned', True) # Default to pruned (embeddings removed)
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, we need to read the flag probe, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if not (flag_byte and flag_byte == b'\x00'):
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
print(f"INFO: Storage status from metadata: is_compact={is_compact}, is_pruned={is_pruned}")
return is_compact, is_pruned
def __init__(self, index_path: str, **kwargs):
from . import faiss
@@ -258,6 +193,10 @@ class HNSWSearcher(LeannBackendSearcherInterface):
if not self.embedding_model:
print("WARNING: embedding_model not found in meta.json. Recompute will fail if attempted.")
# Check for embedding model override (not allowed)
if 'embedding_model' in kwargs and kwargs['embedding_model'] != self.embedding_model:
raise ValueError(f"Embedding model override not allowed. Index uses '{self.embedding_model}', but got '{kwargs['embedding_model']}'")
path = Path(index_path)
self.index_dir = path.parent
self.index_prefix = path.stem
@@ -274,7 +213,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Get storage status from metadata with user overrides
self.is_compact, self.is_pruned = self._get_index_storage_status_from_meta()
# Allow override of storage parameters via kwargs
if 'is_compact' in kwargs:
self.is_compact = kwargs['is_compact']
if 'is_pruned' in kwargs:
self.is_pruned = kwargs['is_pruned']
# Validate configuration constraints
if not self.is_compact and kwargs.get("is_skip_neighbors", False):
@@ -315,7 +261,7 @@ class HNSWSearcher(LeannBackendSearcherInterface):
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 200)
ef = kwargs.get("ef", 128)
if self.is_pruned:
print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.")
@@ -324,13 +270,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
passages_file = kwargs.get("passages_file")
if not passages_file:
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
# Pass the metadata file instead of a single passage file
meta_file_path = self.index_dir / f"{self.index_prefix}.index.meta.json"
if meta_file_path.exists():
passages_file = str(meta_file_path)
print(f"INFO: Using metadata file for lazy loading: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
raise RuntimeError(f"FATAL: Index is pruned but metadata file not found: {meta_file_path}")
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
@@ -351,9 +297,11 @@ class HNSWSearcher(LeannBackendSearcherInterface):
faiss.normalize_L2(query)
try:
self._index.hnsw.efSearch = ef
params = faiss.SearchParametersHNSW()
params.efSearch = ef
params.zmq_port = kwargs.get("zmq_port", self.zmq_port)
params.efSearch = ef
params.beam_size = 2 # Match research system beam_size
batch_size = query.shape[0]
distances = np.empty((batch_size, top_k), dtype=np.float32)
@@ -361,15 +309,27 @@ class HNSWSearcher(LeannBackendSearcherInterface):
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
# 🐛 DEBUG: Print raw faiss results before conversion
print(f"🔍 DEBUG HNSW Search Results:")
print(f" Query shape: {query.shape}")
print(f" Top_k: {top_k}")
print(f" Raw faiss indices: {labels[0] if len(labels) > 0 else 'No results'}")
print(f" Raw faiss distances: {distances[0] if len(distances) > 0 else 'No results'}")
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
for batch_idx, batch_labels in enumerate(labels):
batch_string_labels = []
for int_label in batch_labels:
print(f" Batch {batch_idx} conversion:")
for i, int_label in enumerate(batch_labels):
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
string_id = self.label_map[int_label]
batch_string_labels.append(string_id)
print(f" faiss[{int_label}] -> passage_id '{string_id}' (distance: {distances[batch_idx][i]:.4f})")
else:
batch_string_labels.append(f"unknown_{int_label}")
unknown_id = f"unknown_{int_label}"
batch_string_labels.append(unknown_id)
print(f" faiss[{int_label}] -> {unknown_id} (NOT FOUND in label_map!)")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
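For reference, the int-to-string conversion above depends on leann.labels.map, a pickled dict from raw faiss integer labels to string passage IDs. A minimal sketch of that lookup outside the searcher (same "unknown_" fallback naming as above; the path is a placeholder):

import pickle

def to_passage_ids(int_labels, label_map_path="leann.labels.map"):
    """Map raw faiss integer labels to string passage IDs, flagging unknowns."""
    with open(label_map_path, "rb") as f:
        label_map = pickle.load(f)  # {int_label: "string_passage_id"}
    return [label_map.get(i, f"unknown_{i}") for i in int_labels]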

View File

@@ -56,22 +56,33 @@ class SimplePassageLoader:
def __len__(self) -> int:
return len(self.passages_data)
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
Load passages via the metadata file, using PassageManager for lazy loading
"""
if not os.path.exists(passages_file):
raise FileNotFoundError(f"Passages file {passages_file} not found.")
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
meta = json.load(f)
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Import PassageManager dynamically to avoid circular imports
import sys
import importlib.util
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
@@ -80,24 +91,38 @@ def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
print(f"Initialized lazy passage loading for {len(label_map)} passages")
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
int_id = int(passage_id)
if int_id in self.label_map:
string_id = self.label_map[int_id]
passage_data = self.passage_manager.get_passage(string_id)
if passage_data and passage_data.get("text"):
return {"text": passage_data["text"]}
else:
print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
return {"text": ""}
else:
print(f"DEBUG: ID {int_id} not found in label_map")
return {"text": ""}
except Exception as e:
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
return {"text": ""}
def __len__(self) -> int:
return len(self.label_map)
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
return LazyPassageLoader(passage_manager, label_map)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
@@ -183,7 +208,20 @@ def create_hnsw_embedding_server(
passages = SimplePassageLoader(passages_data)
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
passages = load_passages_from_file(passages_file)
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
passages = load_passages_from_metadata(passages_file)
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
meta_files = list(passages_dir.glob("*.meta.json"))
if meta_files:
print(f"Found metadata file: {meta_files[0]}, using lazy loading")
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = SimplePassageLoader() # Use empty loader to avoid massive warnings
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
@@ -252,6 +290,11 @@ def create_hnsw_embedding_server(
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# Validate no empty texts
for i, text in enumerate(texts_batch):
if not text or text.strip() == "":
raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
# E5 model preprocessing
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
@@ -398,14 +441,12 @@ def create_hnsw_embedding_server(
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
print(f"DEBUG: Looking up passage ID {nid}")
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
txt = txtinfo["text"]
print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
texts.append(txt)
lookup_timer.print_elapsed()

View File

@@ -1,4 +1,4 @@
# File: packages/leann-backend-hnsw/pyproject.toml
# packages/leann-backend-hnsw/pyproject.toml
[build-system]
requires = ["scikit-build-core>=0.10", "numpy", "swig"]
@@ -10,7 +10,6 @@ version = "0.1.0"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = ["leann-core==0.1.0", "numpy"]
# Revert to the most standard scikit-build-core configuration
[tool.scikit-build]
wheel.packages = ["leann_backend_hnsw"]
editable.mode = "redirect"

View File

@@ -1,345 +1,185 @@
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
from typing import List, Dict, Any, Optional
import numpy as np
import os
#!/usr/bin/env python3
"""
Core API for the LEANN project, updated to use the original, verified
embedding logic from the reference implementation.
"""
import json
import pickle
import numpy as np
from pathlib import Path
import openai
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
import uuid
import pickle
# --- Helper Functions for Embeddings ---
from .registry import BACKEND_REGISTRY
from .interface import LeannBackendFactoryInterface
def _get_openai_client():
"""Initializes and returns an OpenAI client, ensuring the API key is set."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable not set, which is required for OpenAI models.")
return openai.OpenAI(api_key=api_key)
# --- The Correct, Verified Embedding Logic from old_code.py ---
def _is_openai_model(model_name: str) -> bool:
"""Checks if the model is likely an OpenAI embedding model."""
# This is a simple check, can be improved with a more robust list.
return "ada" in model_name or "babbage" in model_name or model_name.startswith("text-embedding-")
def _compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings for a list of text chunks using either SentenceTransformers or OpenAI."""
if _is_openai_model(model_name):
print(f"INFO: Computing embeddings for {len(chunks)} chunks using OpenAI model '{model_name}'...")
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=chunks)
embeddings = [item.embedding for item in response.data]
else:
def compute_embeddings(chunks: List[str], model_name: str) -> np.ndarray:
"""Computes embeddings using sentence-transformers for consistent results."""
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
print(f"INFO: Computing embeddings for {len(chunks)} chunks using SentenceTransformer model '{model_name}'...")
embeddings = model.encode(chunks, show_progress_bar=True)
except ImportError as e:
raise RuntimeError(
f"sentence-transformers not available. Install with: pip install sentence-transformers"
) from e
return np.asarray(embeddings, dtype=np.float32)
def _get_embedding_dimensions(model_name: str) -> int:
"""Gets the embedding dimensions for a given model."""
print(f"INFO: Calculating dimensions for model '{model_name}'...")
if _is_openai_model(model_name):
client = _get_openai_client()
response = client.embeddings.create(model=model_name, input=["dummy text"])
return len(response.data[0].embedding)
else:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
if dimension is None:
raise ValueError(f"Model '{model_name}' does not have a valid embedding dimension.")
return dimension
# Load model using sentence-transformers
model = SentenceTransformer(model_name)
# Generate embeddings
embeddings = model.encode(chunks, convert_to_numpy=True, show_progress_bar=True, batch_size=64)
return embeddings
# --- Core API Classes (Restored and Unchanged) ---
@dataclass
class SearchResult:
"""Represents a single search result."""
id: str
score: float
text: str
metadata: Dict[str, Any] = field(default_factory=dict)
class PassageManager:
"""Manages passage data and lazy loading from JSONL files."""
def __init__(self, passage_sources: List[Dict[str, Any]]):
self.offset_maps = {}
self.passage_files = {}
self.global_offset_map = {} # Combined map for fast lookup
for source in passage_sources:
if source["type"] == "jsonl":
passage_file = source["path"]
index_file = source["index_path"]
if not os.path.exists(index_file):
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, 'rb') as f:
offset_map = pickle.load(f)
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
self.offset_maps[passage_file] = offset_map
self.passage_files[passage_file] = passage_file
# Build global map for O(1) lookup
for passage_id, offset in offset_map.items():
self.global_offset_map[passage_id] = (passage_file, offset)
def get_passage(self, passage_id: str) -> Dict[str, Any]:
"""Lazy load a passage by ID."""
for passage_file, offset_map in self.offset_maps.items():
if passage_id in offset_map:
offset = offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
line = f.readline()
return json.loads(line)
if passage_id in self.global_offset_map:
passage_file, offset = self.global_offset_map[passage_id]
with open(passage_file, 'r', encoding='utf-8') as f:
f.seek(offset)
return json.loads(f.readline())
raise KeyError(f"Passage ID not found: {passage_id}")
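The offset maps consumed here are byte offsets into the passages JSONL, recorded with f.tell() at build time (see LeannBuilder.build_index below). A minimal sketch of producing such an .idx file for an existing JSONL, mirroring that logic:

import json
import pickle

def build_offset_index(jsonl_path: str, idx_path: str) -> None:
    """Record the byte offset of each passage line, keyed by its string ID."""
    offsets = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        while True:
            pos = f.tell()
            line = f.readline()
            if not line:
                break
            offsets[json.loads(line)["id"]] = pos
    with open(idx_path, "wb") as f:
        pickle.dump(offsets, f)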
# --- Core Classes ---
class LeannBuilder:
"""
The builder computes embeddings for the added text chunks, builds the
vector index, and saves the index metadata.
"""
def __init__(self, backend_name: str, embedding_model: str = "sentence-transformers/all-mpnet-base-v2", dimensions: Optional[int] = None, **backend_kwargs):
def __init__(self, backend_name: str, embedding_model: str = "facebook/contriever-msmarco", dimensions: Optional[int] = None, **backend_kwargs):
self.backend_name = backend_name
backend_factory: LeannBackendFactoryInterface | None = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found or not registered.")
self.backend_factory = backend_factory
self.embedding_model = embedding_model
self.dimensions = dimensions
self.backend_kwargs = backend_kwargs
self.chunks: List[Dict[str, Any]] = []
print(f"INFO: LeannBuilder initialized with '{backend_name}' backend.")
def add_text(self, text: str, metadata: Optional[Dict[str, Any]] = None):
if metadata is None:
metadata = {}
# Check if ID is provided in metadata
passage_id = metadata.get('id')
if passage_id is None:
passage_id = str(uuid.uuid4())
else:
# Validate uniqueness
existing_ids = {chunk['id'] for chunk in self.chunks}
if passage_id in existing_ids:
raise ValueError(f"Duplicate passage ID: {passage_id}")
# Store the definitive ID with the chunk
chunk_data = {
"id": passage_id,
"text": text,
"metadata": metadata
}
if metadata is None: metadata = {}
passage_id = metadata.get('id', str(uuid.uuid4()))
chunk_data = {"id": passage_id, "text": text, "metadata": metadata}
self.chunks.append(chunk_data)
def build_index(self, index_path: str):
if not self.chunks:
raise ValueError("No chunks added. Use add_text() first.")
if self.dimensions is None:
self.dimensions = _get_embedding_dimensions(self.embedding_model)
print(f"INFO: Auto-detected dimensions for '{self.embedding_model}': {self.dimensions}")
if not self.chunks: raise ValueError("No chunks added.")
if self.dimensions is None: self.dimensions = len(compute_embeddings(["dummy"], self.embedding_model)[0])
path = Path(index_path)
index_dir = path.parent
index_name = path.name
# Ensure the directory exists
index_dir.mkdir(parents=True, exist_ok=True)
# Create the passages.jsonl file and offset index
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
offset_map = {}
with open(passages_file, 'w', encoding='utf-8') as f:
for chunk in self.chunks:
offset = f.tell()
passage_data = {
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk["metadata"]
}
json.dump(passage_data, f, ensure_ascii=False)
json.dump({"id": chunk["id"], "text": chunk["text"], "metadata": chunk["metadata"]}, f, ensure_ascii=False)
f.write('\n')
offset_map[chunk["id"]] = offset
# Save the offset map
with open(offset_file, 'wb') as f:
pickle.dump(offset_map, f)
# Compute embeddings
with open(offset_file, 'wb') as f: pickle.dump(offset_map, f)
texts_to_embed = [c["text"] for c in self.chunks]
embeddings = _compute_embeddings(texts_to_embed, self.embedding_model)
# Extract string IDs for the backend
embeddings = compute_embeddings(texts_to_embed, self.embedding_model)
string_ids = [chunk["id"] for chunk in self.chunks]
# Build the vector index
current_backend_kwargs = self.backend_kwargs.copy()
current_backend_kwargs['dimensions'] = self.dimensions
current_backend_kwargs = {**self.backend_kwargs, 'dimensions': self.dimensions}
builder_instance = self.backend_factory.builder(**current_backend_kwargs)
builder_instance.build(embeddings, string_ids, index_path, **current_backend_kwargs)
# Create the lightweight meta.json file
leann_meta_path = index_dir / f"{index_name}.meta.json"
meta_data = {
"version": "1.0",
"backend_name": self.backend_name,
"embedding_model": self.embedding_model,
"dimensions": self.dimensions,
"backend_kwargs": self.backend_kwargs,
"passage_sources": [
{
"type": "jsonl",
"path": str(passages_file),
"index_path": str(offset_file)
}
]
"version": "1.0", "backend_name": self.backend_name, "embedding_model": self.embedding_model,
"dimensions": self.dimensions, "backend_kwargs": self.backend_kwargs,
"passage_sources": [{"type": "jsonl", "path": str(passages_file), "index_path": str(offset_file)}]
}
with open(leann_meta_path, 'w', encoding='utf-8') as f:
json.dump(meta_data, f, indent=2)
print(f"INFO: Leann metadata saved to {leann_meta_path}")
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute # Pruned only if compact and recompute
with open(leann_meta_path, 'w', encoding='utf-8') as f: json.dump(meta_data, f, indent=2)
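With the HNSW flags added above, the written meta.json looks roughly like this (illustrative values only; paths are abbreviated and dimensions depends on the embedding model):

{
  "version": "1.0",
  "backend_name": "hnsw",
  "embedding_model": "facebook/contriever-msmarco",
  "dimensions": 768,
  "backend_kwargs": {"is_compact": true, "is_recompute": true},
  "passage_sources": [
    {"type": "jsonl", "path": ".../demo.index.passages.jsonl", "index_path": ".../demo.index.passages.idx"}
  ],
  "is_compact": true,
  "is_pruned": true
}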
class LeannSearcher:
"""
The searcher loads the index and its metadata, and serves search
queries against it.
"""
def __init__(self, index_path: str, **backend_kwargs):
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}. Was the index built with LeannBuilder?")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
self.meta_data = json.load(f)
meta_path_str = f"{index_path}.meta.json"
if not Path(meta_path_str).exists(): raise FileNotFoundError(f"Leann metadata file not found at {meta_path_str}")
with open(meta_path_str, 'r', encoding='utf-8') as f: self.meta_data = json.load(f)
backend_name = self.meta_data['backend_name']
self.embedding_model = self.meta_data['embedding_model']
# Initialize the passage manager
passage_sources = self.meta_data.get('passage_sources', [])
self.passage_manager = PassageManager(passage_sources)
self.passage_manager = PassageManager(self.meta_data.get('passage_sources', []))
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' (from index file) not found or not registered.")
final_kwargs = backend_kwargs.copy()
final_kwargs['meta'] = self.meta_data
if backend_factory is None: raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get('backend_kwargs', {}), **backend_kwargs}
self.backend_impl = backend_factory.searcher(index_path, **final_kwargs)
print(f"INFO: LeannSearcher initialized with '{backend_name}' backend using index '{index_path}'.")
def search(self, query: str, top_k: int = 5, **search_kwargs):
query_embedding = _compute_embeddings([query], self.embedding_model)
def search(self, query: str, top_k: int = 5, **search_kwargs) -> List[SearchResult]:
print(f"🔍 DEBUG LeannSearcher.search() called:")
print(f" Query: '{query}'")
print(f" Top_k: {top_k}")
print(f" Search kwargs: {search_kwargs}")
query_embedding = compute_embeddings([query], self.embedding_model)
print(f" Generated embedding shape: {query_embedding.shape}")
print(f"🔍 DEBUG Query embedding first 10 values: {query_embedding[0][:10]}")
print(f"🔍 DEBUG Query embedding norm: {np.linalg.norm(query_embedding[0])}")
search_kwargs['embedding_model'] = self.embedding_model
results = self.backend_impl.search(query_embedding, top_k, **search_kwargs)
print(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
for string_id, dist in zip(results['labels'][0], results['distances'][0]):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=string_id,
score=dist,
text=passage_data['text'],
metadata=passage_data.get('metadata', {})
))
except KeyError:
print(f"WARNING: Passage ID '{string_id}' not found in passage files")
if 'labels' in results and 'distances' in results:
print(f" Processing {len(results['labels'][0])} passage IDs:")
for i, (string_id, dist) in enumerate(zip(results['labels'][0], results['distances'][0])):
try:
passage_data = self.passage_manager.get_passage(string_id)
enriched_results.append(SearchResult(
id=string_id, score=dist, text=passage_data['text'], metadata=passage_data.get('metadata', {})
))
print(f" {i+1}. passage_id='{string_id}' -> SUCCESS: {passage_data['text'][:60]}...")
except KeyError:
print(f" {i+1}. passage_id='{string_id}' -> ERROR: Passage not found in PassageManager!")
print(f" Final enriched results: {len(enriched_results)} passages")
return enriched_results
from .chat import get_llm
class LeannChat:
"""
The chat layer manages the conversation with the LLM: it retrieves
context with the searcher, then asks the LLM to generate a response.
"""
def __init__(self, index_path: str, backend_name: Optional[str] = None, llm_model: str = "gpt-4o", **kwargs):
if backend_name is None:
leann_meta_path = Path(index_path).parent / f"{Path(index_path).name}.meta.json"
if not leann_meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {leann_meta_path}.")
with open(leann_meta_path, 'r', encoding='utf-8') as f:
meta_data = json.load(f)
backend_name = meta_data['backend_name']
def __init__(self, index_path: str, llm_config: Optional[Dict[str, Any]] = None, **kwargs):
self.searcher = LeannSearcher(index_path, **kwargs)
self.llm_model = llm_model
def ask(self, question: str, top_k=5, **kwargs):
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
self.llm = get_llm(llm_config)
def ask(self, question: str, top_k=5, **kwargs):
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r.text for r in results])
prompt = f"Context:\n{context}\n\nQuestion: {question}\n\nAnswer:"
print(f"DEBUG: Calling LLM with prompt: {prompt}...")
try:
client = _get_openai_client()
response = client.chat.completions.create(
model=self.llm_model,
messages=[
{"role": "system", "content": "You are a helpful assistant that answers questions based on the provided context."},
{"role": "user", "content": prompt}
]
)
return response.choices[0].message.content
except Exception as e:
print(f"ERROR: Failed to call OpenAI API: {e}")
return f"Error: Could not get a response from the LLM. {e}"
def start_interactive(self):
print("\nLeann Chat started (type 'quit' to exit)")
while True:
try:
user_input = input("You: ").strip()
if user_input.lower() in ['quit', 'exit']:
break
if not user_input:
continue
response = self.ask(user_input)
print(f"Leann: {response}")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
return self.llm.ask(prompt, **kwargs.get("llm_kwargs", {}))
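Putting the rewritten API together, a minimal end-to-end sketch (index path, text, and query are placeholders; assumes the hnsw backend is registered and sentence-transformers is installed):

from leann.api import LeannBuilder, LeannSearcher

# Build: writes passages.jsonl + .idx offsets, the vector index, and meta.json
builder = LeannBuilder("hnsw", embedding_model="facebook/contriever-msmarco")
builder.add_text("LEANN is a toolkit for low-storage vector search.", metadata={"id": "doc-0"})
builder.build_index("./indices/demo/demo.index")

# Search: the meta.json next to the index resolves backend, model, and passages
searcher = LeannSearcher("./indices/demo/demo.index")
for r in searcher.search("what is LEANN?", top_k=3):
    print(r.id, r.score, r.text[:60])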

View File

@@ -73,15 +73,17 @@ class EmbeddingServerManager:
self.server_process = subprocess.Popen(
command,
cwd=project_root,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring
text=True,
encoding='utf-8'
encoding='utf-8',
bufsize=1, # Line buffered
universal_newlines=True
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
@@ -90,7 +92,7 @@ class EmbeddingServerManager:
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
self._print_recent_output()
return False
time.sleep(wait_interval)
@@ -102,19 +104,32 @@ class EmbeddingServerManager:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
# Read any available output
import select
import sys
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[{self.backend_module_name} LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
self.server_process.stderr.close()
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True)
except Exception as e:
print(f"Log monitor error: {e}")
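The startup loop above polls _check_port(port), which is not shown in this diff. A plausible implementation (an assumption, not necessarily the project's actual helper) is a plain TCP connect probe:

import socket

def _check_port(port: int, host: str = "127.0.0.1") -> bool:
    # Assumed helper: True once something accepts TCP connections on the port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(0.2)
        return s.connect_ex((host, port)) == 0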