refactor: passage structure

This commit is contained in:
Andy Lee
2025-07-06 21:48:38 +00:00
parent 5611f708e9
commit 16705fc44a
6 changed files with 289 additions and 139 deletions

View File

@@ -3,7 +3,7 @@ import os
import json
import struct
from pathlib import Path
from typing import Dict, Any
from typing import Dict, Any, List
import contextlib
import threading
import time
@@ -11,6 +11,7 @@ import atexit
import socket
import subprocess
import sys
import pickle
from leann.embedding_server_manager import EmbeddingServerManager
from .convert_to_csr import convert_hnsw_graph_to_csr
@@ -74,7 +75,7 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs):
def build(self, data: np.ndarray, ids: List[str], index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
@@ -89,6 +90,12 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
# Create label map: integer -> string_id
label_map = {i: str_id for i, str_id in enumerate(ids)}
label_map_file = index_dir / "leann.labels.map"
with open(label_map_file, 'wb') as f:
pickle.dump(label_map, f)
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None:
@@ -119,9 +126,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_compact:
self._convert_to_csr(index_file)
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
@@ -155,30 +159,6 @@ class HNSWBuilder(LeannBackendBuilderInterface):
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - this is not critical for index building
pass
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
@@ -282,6 +262,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
self.index_dir = path.parent
self.index_prefix = path.stem
# Load the label map
label_map_file = self.index_dir / "leann.labels.map"
if not label_map_file.exists():
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
with open(label_map_file, 'rb') as f:
self.label_map = pickle.load(f)
index_file = self.index_dir / f"{self.index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
@@ -336,12 +324,13 @@ class HNSWSearcher(LeannBackendSearcherInterface):
passages_file = kwargs.get("passages_file")
if not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Automatically found passages file: {passages_file}")
# Get the passages file path from meta.json
if 'passage_sources' in self.meta and self.meta['passage_sources']:
passage_source = self.meta['passage_sources'][0]
passages_file = passage_source['path']
print(f"INFO: Found passages file from metadata: {passages_file}")
else:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found.")
raise RuntimeError(f"FATAL: Index is pruned but no passage_sources found in metadata.")
zmq_port = kwargs.get("zmq_port", 5557)
server_started = self.embedding_server_manager.start_server(
@@ -372,7 +361,18 @@ class HNSWSearcher(LeannBackendSearcherInterface):
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels), params)
return {"labels": labels, "distances": distances}
# Convert integer labels to string IDs
string_labels = []
for batch_labels in labels:
batch_string_labels = []
for int_label in batch_labels:
if int_label in self.label_map:
batch_string_labels.append(self.label_map[int_label])
else:
batch_string_labels.append(f"unknown_{int_label}")
string_labels.append(batch_string_labels)
return {"labels": string_labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")

View File

@@ -58,21 +58,46 @@ class SimplePassageLoader:
def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
"""
Load passages from a JSON file
Expected format: {"passage_id": "passage_text", ...}
Load passages from a JSONL file with label map support
Expected format: {"id": "passage_id", "text": "passage_text", "metadata": {...}} (one per line)
"""
if not os.path.exists(passages_file):
print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
return SimplePassageLoader()
raise FileNotFoundError(f"Passages file {passages_file} not found.")
try:
with open(passages_file, 'r', encoding='utf-8') as f:
passages_data = json.load(f)
print(f"Loaded {len(passages_data)} passages from {passages_file}")
return SimplePassageLoader(passages_data)
except Exception as e:
print(f"Error loading passages from {passages_file}: {e}")
return SimplePassageLoader()
if not passages_file.endswith('.jsonl'):
raise ValueError(f"Expected .jsonl file format, got: {passages_file}")
# Load label map (int -> string_id)
passages_dir = Path(passages_file).parent
label_map_file = passages_dir / "leann.labels.map"
label_map = {}
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
# Load passages by string ID
string_id_passages = {}
with open(passages_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
passage = json.loads(line)
string_id_passages[passage['id']] = passage['text']
# Create int ID -> text mapping using label map
passages_data = {}
for int_id, string_id in label_map.items():
if string_id in string_id_passages:
passages_data[str(int_id)] = string_id_passages[string_id]
else:
print(f"WARNING: String ID {string_id} from label map not found in passages")
print(f"Loaded {len(passages_data)} passages from JSONL file {passages_file} using label map")
return SimplePassageLoader(passages_data)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,