feat: hnsw embedding server and csr format

Andy Lee
2025-07-05 23:04:41 +00:00
parent 368474d036
commit 0aa84e147b
9 changed files with 959 additions and 154 deletions


@@ -3,7 +3,7 @@ import os
import json
import struct
from pathlib import Path
from typing import Dict
from typing import Dict, Any
import contextlib
import threading
import time
@@ -12,9 +12,7 @@ import socket
import subprocess
import sys
# File: packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
# ... (other imports unchanged) ...
from .convert_to_csr import convert_hnsw_graph_to_csr
from leann.registry import register_backend
from leann.interface import (
@@ -28,7 +26,7 @@ def get_metric_map():
return {
"mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT, # Will need normalization
"cosine": faiss.METRIC_INNER_PRODUCT,
}
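# Illustrative aside, not part of this commit: cosine maps onto
# METRIC_INNER_PRODUCT because inner product over L2-normalized vectors
# equals cosine similarity, which is why the builder normalizes data when
# metric_str == "cosine". A quick self-check:
v = np.random.rand(2, 8).astype(np.float32)
faiss.normalize_L2(v)  # in-place row-wise L2 normalization
# v[0] @ v[1] now equals the cosine similarity of the original rows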
def _check_port(port: int) -> bool:
@@ -69,12 +67,11 @@ class HNSWEmbeddingServerManager:
try:
command = [
sys.executable,
"-m", "packages.leann-backend-hnsw.src.leann_backend_hnsw.hnsw_embedding_server",
"-m", "leann_backend_hnsw.hnsw_embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
]
# Add passages file if provided
if passages_file:
command.extend(["--passages-file", str(passages_file)])
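# Illustrative sketch, not part of this commit: roughly what the spawned
# hnsw_embedding_server process might look like. The wire format (JSON list
# of node ids in, raw float32 bytes out) and the helper name `serve` are
# assumptions for demonstration; the real protocol lives in
# leann_backend_hnsw.hnsw_embedding_server. Assumes pyzmq and
# sentence-transformers are installed.
import zmq
from sentence_transformers import SentenceTransformer

def serve(zmq_port: int, model_name: str, passages: dict):
    model = SentenceTransformer(model_name)
    sock = zmq.Context().socket(zmq.REP)
    sock.bind(f"tcp://*:{zmq_port}")
    while True:
        node_ids = json.loads(sock.recv())            # e.g. ["3", "17"]
        texts = [passages[str(i)] for i in node_ids]
        emb = model.encode(texts).astype(np.float32)  # (len(ids), dim)
        sock.send(emb.tobytes())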
@@ -172,7 +169,29 @@ class HNSWBackend(LeannBackendFactoryInterface):
class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
self.build_params = kwargs.copy()
# --- Configuration defaults with standardized names ---
# Apply defaults and write them back to the build_params dict
# so they can be saved in the metadata file by LeannBuilder.
self.is_compact = self.build_params.setdefault("is_compact", True)
self.is_recompute = self.build_params.setdefault("is_recompute", True) # Default: prune embeddings
# --- Additional Options ---
self.is_skip_neighbors = self.build_params.setdefault("is_skip_neighbors", False)
self.disk_cache_ratio = self.build_params.setdefault("disk_cache_ratio", 0.0)
self.external_storage_path = self.build_params.get("external_storage_path", None)
# --- Standard HNSW parameters ---
self.M = self.build_params.setdefault("M", 32)
self.efConstruction = self.build_params.setdefault("efConstruction", 200)
self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
if self.is_skip_neighbors and not self.is_compact:
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.is_recompute and not self.is_compact:
raise ValueError("is_recompute requires is_compact=True for efficiency")
def build(self, data: np.ndarray, index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
@@ -189,97 +208,297 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_str = self.distance_metric.lower()
metric_enum = get_metric_map().get(metric_str)
print('metric_enum', metric_enum, 'metric_str', metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
# HNSW parameters
M = build_kwargs.get("M", 32) # Max connections per layer
efConstruction = build_kwargs.get("efConstruction", 200) # Size of the dynamic candidate list for construction
M = self.M
efConstruction = self.efConstruction
dim = data.shape[1]
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
# Create HNSW index (IndexHNSWFlat takes the metric directly, so the
# inner-product and L2 cases share one construction path)
index = faiss.IndexHNSWFlat(dim, M, metric_enum)
# Set construction parameters
index.hnsw.efConstruction = efConstruction
# Normalize vectors if using cosine similarity
if metric_str == "cosine":
faiss.normalize_L2(data)
# Add vectors to index
print("INFO: Adding vectors to index...")
index.add(data.shape[0], faiss.swig_ptr(data))
print("INFO: Vectors added to index.")
# Save index
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
if self.is_compact:
self._convert_to_csr(index_file)
# Generate passages file for recompute mode
if self.is_recompute:
self._generate_passages_file(index_dir, index_prefix, **kwargs)
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
try:
mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
print(f"INFO: Converting HNSW index to {mode_str} format...")
csr_temp_file = index_file.with_suffix(".csr.tmp")
success = convert_hnsw_graph_to_csr(
str(index_file),
str(csr_temp_file),
prune_embeddings=self.is_recompute
)
if success:
print("✅ CSR conversion successful.")
import shutil
shutil.move(str(csr_temp_file), str(index_file))
print(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
else:
# Clean up and fail fast
if csr_temp_file.exists():
os.remove(csr_temp_file)
raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
except Exception as e:
print(f"💥 ERROR: CSR conversion failed. Exception: {e}")
raise
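# Illustrative sketch, not part of this commit: the essence of a CSR
# (compressed sparse row) graph layout. Per-node neighbor lists collapse
# into two flat arrays, an offsets array and a concatenated-ids array, so
# node i's neighbors are indices[indptr[i]:indptr[i+1]].
neighbor_lists = [[1, 2], [0], [0, 1, 3], [2]]
indptr = np.cumsum([0] + [len(n) for n in neighbor_lists])  # [0 2 3 6 7]
indices = np.concatenate(neighbor_lists)                    # [1 2 0 0 1 3 2]
assert list(indices[indptr[2]:indptr[3]]) == [0, 1, 3]      # neighbors of node 2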
def _generate_passages_file(self, index_dir: Path, index_prefix: str, **kwargs):
"""Generate passages file for recompute mode"""
try:
chunks = kwargs.get('chunks', [])
if not chunks:
print("INFO: No chunks data provided, skipping passages file generation")
return
# Generate node_id to text mapping
passages_data = {}
for node_id, chunk in enumerate(chunks):
passages_data[str(node_id)] = chunk["text"]
# Save passages file
passages_file = index_dir / f"{index_prefix}.passages.json"
with open(passages_file, 'w', encoding='utf-8') as f:
json.dump(passages_data, f, ensure_ascii=False, indent=2)
print(f"✅ Generated passages file for recompute mode at '{passages_file}' ({len(passages_data)} passages)")
except Exception as e:
print(f"💥 ERROR: Failed to generate passages file. Exception: {e}")
# Don't raise - passages file generation is not critical for index building
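# Hypothetical consumer code, for reference: the embedding server can
# rehydrate the mapping written above with a plain json.load, since keys
# are stringified node ids.
def load_passages(passages_file: str) -> dict:
    """Hypothetical helper mirroring _generate_passages_file's output shape."""
    with open(passages_file, encoding="utf-8") as f:
        return json.load(f)  # {"0": "text of node 0", "1": ..., ...}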
class HNSWSearcher(LeannBackendSearcherInterface):
def _get_index_storage_status(self, index_file: Path) -> tuple[bool, bool]:
"""
Robustly determines the index's storage status by parsing the file.
Returns:
A tuple (is_compact, is_pruned).
"""
if not index_file.exists():
return False, False
with open(index_file, 'rb') as f:
try:
def read_struct(fmt):
size = struct.calcsize(fmt)
data = f.read(size)
if len(data) != size:
raise EOFError(f"File ended unexpectedly reading struct fmt '{fmt}'.")
return struct.unpack(fmt, data)[0]
def skip_vector(element_size):
count = read_struct('<Q')
f.seek(count * element_size, 1)
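# Assumed layout being parsed below (it mirrors faiss's write_index for
# IndexHNSW; stated here as an aid, not an official spec): index header =
# fourcc<u32>, d<i32>, ntotal<i64>, two reserved i64s, is_trained<bool>,
# metric_type<i32> (+ metric_arg<f32> if metric_type > 1), followed by the
# HNSW vectors assign_probas<f64>, cum_nneighbor_per_level<i32>, levels<i32>.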
# 1. Read up to the compact flag
read_struct('<I'); read_struct('<i'); read_struct('<q');
read_struct('<q'); read_struct('<q'); read_struct('<?')
metric_type = read_struct('<i')
if metric_type > 1: read_struct('<f')
skip_vector(8); skip_vector(4); skip_vector(4)
# 2. Check if there's a compact flag byte
# Try to read the compact flag, but handle both old and new formats
pos_before_compact = f.tell()
try:
is_compact = read_struct('<?')
print(f"INFO: Detected is_compact flag as: {is_compact}")
except (EOFError, struct.error):
# Old format without compact flag - assume non-compact
f.seek(pos_before_compact)
is_compact = False
print(f"INFO: No compact flag found, assuming is_compact=False")
# 3. Read storage FourCC to determine if pruned
is_pruned = False
try:
if is_compact:
# For compact, we need to skip pointers and scalars to get to the storage FourCC
skip_vector(8) # level_ptr
skip_vector(8) # node_offsets
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
storage_fourcc = read_struct('<I')
else:
# For non-compact, probe one byte for the flag, then skip offsets and neighbors
pos_before_probe = f.tell()
flag_byte = f.read(1)
if flag_byte != b'\x00':  # not a flag byte (or EOF): rewind the probe
f.seek(pos_before_probe)
skip_vector(8); skip_vector(4) # offsets, neighbors
read_struct('<i'); read_struct('<i'); read_struct('<i');
read_struct('<i'); read_struct('<i')
# Now we are at the storage. The entire rest is storage blob.
storage_fourcc = struct.unpack('<I', f.read(4))[0]
NULL_INDEX_FOURCC = int.from_bytes(b'null', 'little')
if storage_fourcc == NULL_INDEX_FOURCC:
is_pruned = True
except (EOFError, struct.error):
# Cannot determine pruning status, assume not pruned
pass
print(f"INFO: Detected is_pruned as: {is_pruned}")
return is_compact, is_pruned
except (EOFError, struct.error) as e:
print(f"WARNING: Could not parse index file to detect format: {e}. Assuming standard, not pruned.")
return False, False
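# Quick sanity check of the FourCC convention used above (illustrative):
# the four bytes b'null' read back as a little-endian u32.
assert int.from_bytes(b'null', 'little') == 0x6C6C756E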
def __init__(self, index_path: str, **kwargs):
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
metric_str = kwargs.get("distance_metric", "mips").lower()
# Store configuration and paths for later use
self.config = kwargs.copy()
self.config["index_path"] = index_path
self.index_dir = index_dir
self.index_prefix = index_prefix
metric_str = self.config.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
dimensions = kwargs.get("dimensions")
dimensions = self.config.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to HNSWSearcher.")
try:
# Load FAISS HNSW index
index_file = index_dir / f"{index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self._index = faiss.read_index(str(index_file))
self.metric_str = metric_str
self.embedding_server_manager = HNSWEmbeddingServerManager()
print("✅ HNSW index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load HNSW index. Exception: {e}")
raise
index_file = index_dir / f"{index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
self.is_compact, self.is_pruned = self._get_index_storage_status(index_file)
# Validate configuration constraints
if not self.is_compact and self.config.get("is_skip_neighbors", False):
raise ValueError("is_skip_neighbors can only be used with is_compact=True")
if self.config.get("is_recompute", False) and self.config.get("external_storage_path"):
raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously")
hnsw_config = faiss.HNSWIndexConfig()
hnsw_config.is_compact = self.is_compact
# Apply additional configuration options with strict validation
hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False)
# If index is pruned, force recompute mode regardless of user setting
hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False)
hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0)
hnsw_config.external_storage_path = self.config.get("external_storage_path")
hnsw_config.zmq_port = self.config.get("zmq_port", 5557)
# CRITICAL ASSERTION: If index is pruned, recompute MUST be enabled
if self.is_pruned and not hnsw_config.is_recompute:
raise RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.")
print(f"INFO: Loading index with is_compact={self.is_compact}, is_pruned={self.is_pruned}")
print(f"INFO: Config - skip_neighbors={hnsw_config.is_skip_neighbors}, recompute={hnsw_config.is_recompute}")
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
if self.is_compact:
print("✅ Compact CSR format HNSW index loaded successfully.")
else:
print("✅ Standard HNSW index loaded successfully.")
self.metric_str = metric_str
self.embedding_server_manager = HNSWEmbeddingServerManager()
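# Hypothetical call site, for reference (keyword names mirror the config
# reads above; the path and dimension are placeholders):
searcher = HNSWSearcher("indexes/demo.index", distance_metric="mips",
                        dimensions=768, zmq_port=5557)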
def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path:
"""Get the appropriate index file path based on format"""
# We always use the same filename now, format is detected internally
return index_dir / f"{index_prefix}.index"
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
from . import faiss
ef = kwargs.get("ef", 200) # Size of the dynamic candidate list for search
# Merge config with search-time kwargs
search_config = self.config.copy()
search_config.update(kwargs)
ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search
# Recompute parameters
recompute_neighbor_embeddings = kwargs.get("recompute_neighbor_embeddings", False)
zmq_port = kwargs.get("zmq_port", 5556)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = kwargs.get("passages_file", None)
zmq_port = search_config.get("zmq_port", 5557)
embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = search_config.get("passages_file", None)
if recompute_neighbor_embeddings:
print(f"INFO: HNSW ZMQ mode enabled - ensuring embedding server is running")
# For recompute mode, try to find the passages file automatically
if self.is_pruned and not passages_file:
potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json"
print(f"DEBUG: Checking for passages file at: {potential_passages_file}")
if potential_passages_file.exists():
passages_file = str(potential_passages_file)
print(f"INFO: Found passages file for recompute mode: {passages_file}")
else:
print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}")
# If index is pruned (embeddings removed), we MUST start embedding server for recompute
if self.is_pruned:
print(f"INFO: Index is pruned - starting embedding server for recompute")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
print(f"WARNING: Failed to start HNSW embedding server, falling back to standard search")
kwargs['recompute_neighbor_embeddings'] = False
# CRITICAL: Check passages file exists - fail fast if not
if not passages_file:
raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.")
# Check if server is already running first
if _check_port(zmq_port):
print(f"INFO: Embedding server already running on port {zmq_port}")
else:
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}")
# Give server extra time to fully initialize
print(f"INFO: Waiting for embedding server to fully initialize...")
time.sleep(3)
# Final verification
if not _check_port(zmq_port):
raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}")
else:
print(f"INFO: Index has embeddings stored - no recompute needed")
if query.dtype != np.float32:
query = query.astype(np.float32)
@@ -299,23 +518,14 @@ class HNSWSearcher(LeannBackendSearcherInterface):
distances = np.empty((batch_size, top_k), dtype=np.float32)
labels = np.empty((batch_size, top_k), dtype=np.int64)
if recompute_neighbor_embeddings:
# Use custom search with recompute
# This would require implementing custom HNSW search logic
# For now, we'll fall back to standard search
print("WARNING: Recompute functionality for HNSW not yet implemented, using standard search")
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
else:
# Standard FAISS search using SWIG API
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
# Use standard FAISS search - recompute is handled internally by FAISS
self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels))
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
raise
def __del__(self):
if hasattr(self, 'embedding_server_manager'):