Files

288 lines
11 KiB
Python

import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
import numpy as np
from leann.interface import (
LeannBackendBuilderInterface,
LeannBackendFactoryInterface,
LeannBackendSearcherInterface,
)
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
def get_metric_map():
    """Return the mapping from distance-metric name to FAISS metric constant.

    FAISS is imported lazily so that importing this module does not load
    the compiled extension. "cosine" shares the inner-product constant
    with "mips"; callers in this file L2-normalize vectors for "cosine"
    before handing them to FAISS.
    """
    from . import faiss  # type: ignore

    inner_product = faiss.METRIC_INNER_PRODUCT
    return dict(
        mips=inner_product,
        l2=faiss.METRIC_L2,
        cosine=inner_product,
    )
def normalize_l2(data: np.ndarray) -> np.ndarray:
    """Scale every row of *data* to unit Euclidean length.

    Rows whose norm is exactly zero are divided by 1 instead, so they
    come back unchanged (all zeros) rather than as NaN. The input array
    is not modified; a new array is returned.
    """
    row_lengths = np.linalg.norm(data, axis=1, keepdims=True)
    # Substitute 1 for zero norms to avoid division by zero.
    safe_lengths = np.where(row_lengths == 0, row_lengths.dtype.type(1), row_lengths)
    return data / safe_lengths
@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
    """Factory registered under the name "hnsw".

    Hands out the HNSW builder and searcher defined in this module.
    """

    @staticmethod
    def builder(**kwargs) -> LeannBackendBuilderInterface:
        """Create an ``HNSWBuilder``, forwarding all keyword arguments."""
        hnsw_builder = HNSWBuilder(**kwargs)
        return hnsw_builder

    @staticmethod
    def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
        """Create an ``HNSWSearcher`` for the index at *index_path*."""
        hnsw_searcher = HNSWSearcher(index_path, **kwargs)
        return hnsw_searcher
class HNSWBuilder(LeannBackendBuilderInterface):
    """Build a FAISS ``IndexHNSWFlat`` index and write it to disk.

    Recognised keyword arguments (all optional):
        is_compact: convert the written index to CSR format (default True).
        is_recompute: prune stored embeddings from the index (default True).
        M: second argument passed to ``faiss.IndexHNSWFlat`` (default 32).
        efConstruction: value assigned to ``index.hnsw.efConstruction``
            (default 200).
        distance_metric: "mips", "l2" or "cosine" (default "mips").
        dimensions: vector dimensionality; defaults to ``data.shape[1]``.

    The effective parameters, including applied defaults, are kept in
    ``self.build_params``.
    """

    def __init__(self, **kwargs):
        self.build_params = kwargs.copy()
        self.is_compact = self.build_params.setdefault("is_compact", True)
        self.is_recompute = self.build_params.setdefault("is_recompute", True)
        self.M = self.build_params.setdefault("M", 32)
        self.efConstruction = self.build_params.setdefault("efConstruction", 200)
        self.distance_metric = self.build_params.setdefault("distance_metric", "mips")
        self.dimensions = self.build_params.get("dimensions")
        if not self.is_recompute and self.is_compact:
            # Auto-correct: non-recompute requires non-compact storage for HNSW
            logger.warning(
                "is_recompute=False requires non-compact HNSW. Forcing is_compact=False."
            )
            self.is_compact = False
            self.build_params["is_compact"] = False

    def build(self, data: np.ndarray, ids: list[str], index_path: str, **kwargs):
        """Build the index from *data* and write it next to *index_path*.

        Files written into ``Path(index_path).parent``:
            ``<stem>.index``   - the FAISS index (afterwards converted to CSR
                                 or pruned in place, depending on the flags).
            ``<stem>.ids.txt`` - one passage ID per line; line *i* corresponds
                                 to FAISS label *i*.

        Args:
            data: Embeddings of shape (N, D). Converted to C-contiguous
                float32 if necessary.
            ids: Passage IDs, one per row of *data*.
            index_path: Path whose parent directory and stem name the
                output files.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the distance metric is unsupported, *data* is not
                2-D, or its width differs from the configured ``dimensions``.
            RuntimeError: If CSR conversion fails.
        """
        from . import faiss  # type: ignore

        # Validate everything before touching the filesystem.
        metric_name = self.distance_metric.lower()
        metric_enum = get_metric_map().get(metric_name)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")

        if data.ndim != 2:
            raise ValueError(
                f"data must be a 2-D array of shape (N, D), got shape {data.shape}"
            )
        dim = self.dimensions or data.shape[1]
        if dim != data.shape[1]:
            # index.add() below reads N * dim floats from a raw pointer, so a
            # mismatch would misalign vectors or read past the buffer.
            raise ValueError(
                f"Configured dimensions={dim} does not match data.shape[1]={data.shape[1]}"
            )

        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem
        index_dir.mkdir(parents=True, exist_ok=True)

        if data.dtype != np.float32:
            logger.warning(f"Converting data to float32, shape: {data.shape}")
        # The buffer is handed to FAISS as a raw pointer and must be row-major
        # float32. ascontiguousarray is a no-op when that is already the case.
        data = np.ascontiguousarray(data, dtype=np.float32)

        index = faiss.IndexHNSWFlat(dim, self.M, metric_enum)
        index.hnsw.efConstruction = self.efConstruction
        if metric_name == "cosine":
            # Cosine similarity is inner product on unit-length vectors.
            data = normalize_l2(data)
        index.add(data.shape[0], faiss.swig_ptr(data))

        index_file = index_dir / f"{index_prefix}.index"
        faiss.write_index(index, str(index_file))

        # Persist ID map so searcher can map FAISS integer labels back to passage IDs.
        # Best effort: a failure here is logged and the build continues.
        # NOTE: the format is line-based, so an ID containing a newline would
        # shift every following entry.
        try:
            idmap_file = index_dir / f"{index_prefix}.ids.txt"
            written = 0
            with open(idmap_file, "w", encoding="utf-8") as f:
                for id_str in ids:
                    f.write(f"{id_str}\n")
                    written += 1
            if written != data.shape[0]:
                logger.warning(
                    f"ID map has {written} entries but index has {data.shape[0]} vectors"
                )
        except Exception as e:
            logger.warning(f"Failed to write ID map: {e}")

        if self.is_compact:
            self._convert_to_csr(index_file)
        elif self.is_recompute:
            prune_hnsw_embeddings_inplace(str(index_file))

    def _convert_to_csr(self, index_file: Path):
        """Convert built index to CSR format, replacing *index_file* in place.

        The converter writes to a temporary sibling file, which is moved over
        the original only on success. The temporary file is removed on every
        failure path, including when the converter raises.

        Raises:
            RuntimeError: If the converter reports failure.
        """
        mode_str = "CSR-pruned" if self.is_recompute else "CSR-standard"
        logger.info(f"INFO: Converting HNSW index to {mode_str} format...")
        csr_temp_file = index_file.with_suffix(".csr.tmp")
        try:
            success = convert_hnsw_graph_to_csr(
                str(index_file), str(csr_temp_file), prune_embeddings=self.is_recompute
            )
            if not success:
                # Fail fast; the finally clause cleans up the temp file.
                raise RuntimeError("CSR conversion failed - cannot proceed with compact format")
            logger.info("✅ CSR conversion successful.")
            shutil.move(str(csr_temp_file), str(index_file))
        finally:
            # After a successful move the temp file no longer exists, so this
            # only removes leftovers from a failed or aborted conversion.
            csr_temp_file.unlink(missing_ok=True)
        logger.info(f"INFO: Replaced original index with {mode_str} version at '{index_file}'")
class HNSWSearcher(BaseSearcher):
    """Search a FAISS HNSW index previously written by ``HNSWBuilder``.

    Index settings (distance metric, compact/pruned flags) are read from
    ``self.meta``; ``self.meta``, ``self.index_dir`` and ``self.index_path``
    are expected to be set by ``BaseSearcher.__init__``.
    """

    def __init__(self, index_path: str, **kwargs):
        """Load the index file and, if present, the ID map.

        Raises:
            ValueError: If the stored distance metric is unsupported.
            FileNotFoundError: If ``<stem>.index`` does not exist.
        """
        super().__init__(
            index_path,
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
            **kwargs,
        )
        from . import faiss  # type: ignore

        # `or {}` also covers metadata that stores an explicit null.
        backend_meta_kwargs = self.meta.get("backend_kwargs") or {}

        self.distance_metric = backend_meta_kwargs.get("distance_metric", "mips").lower()
        metric_enum = get_metric_map().get(self.distance_metric)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")

        # Top-level meta keys take precedence over the values in backend_kwargs.
        self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
        default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
        self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))

        index_file = self.index_dir / f"{self.index_path.stem}.index"
        if not index_file.exists():
            raise FileNotFoundError(f"HNSW index file not found at {index_file}")

        hnsw_config = faiss.HNSWIndexConfig()
        hnsw_config.is_compact = self.is_compact
        hnsw_config.is_recompute = (
            self.is_pruned
        )  # In C++ code, it's called is_recompute, but it's only for loading IIUC.
        self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)

        # Load ID map if available (best effort: line i holds the ID for label i).
        self._id_map: list[str] = []
        try:
            idmap_file = self.index_dir / f"{self.index_path.stem}.ids.txt"
            if idmap_file.exists():
                with open(idmap_file, encoding="utf-8") as f:
                    self._id_map = [line.rstrip("\n") for line in f]
        except Exception as e:
            logger.warning(f"Failed to load ID map: {e}")

    def search(
        self,
        query: np.ndarray,
        top_k: int,
        zmq_port: Optional[int] = None,
        complexity: int = 64,
        beam_width: int = 1,
        prune_ratio: float = 0.0,
        recompute_embeddings: bool = True,
        pruning_strategy: Literal["global", "local", "proportional"] = "global",
        batch_size: int = 0,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Search for nearest neighbors using HNSW index.

        Args:
            query: Query vectors (B, D) where B is batch size, D is dimension.
                Converted to C-contiguous float32 if necessary.
            top_k: Number of nearest neighbors to return
            zmq_port: ZMQ port for embedding server communication. Must be
                provided if recompute_embeddings is True.
            complexity: Search complexity/efSearch, higher = more accurate but slower
            beam_width: Number of parallel search paths/beam_size
            prune_ratio: Ratio of neighbors to prune via PQ (0.0-1.0)
            recompute_embeddings: Whether to fetch fresh embeddings from server
            pruning_strategy: PQ candidate selection strategy:
                - "global": Use global PQ queue size for selection (default)
                - "local": Local pruning, sort and select best candidates
                - "proportional": Base selection on new neighbor count ratio
                Any other value is treated as "global".
            batch_size: Neighbor processing batch size, 0=disabled (HNSW-specific)
            **kwargs: Additional HNSW-specific parameters (for legacy compatibility)

        Returns:
            Dict with 'labels' (list of lists of str) and 'distances' (ndarray
            of shape (B, top_k)). Labels are passage IDs when an ID map was
            loaded; labels outside the map, or all labels when no map exists,
            are the stringified FAISS integer label.

        Raises:
            RuntimeError: If recompute_embeddings is False on a pruned index.
            ValueError: If recompute_embeddings is True and zmq_port is None.
        """
        from . import faiss  # type: ignore

        if not recompute_embeddings and self.is_pruned:
            raise RuntimeError(
                "Recompute is required for pruned/compact HNSW index. "
                "Re-run search with --recompute, or rebuild with --no-recompute and --no-compact."
            )
        if recompute_embeddings and zmq_port is None:
            raise ValueError("zmq_port must be provided if recompute_embeddings is True")

        # The query is handed to FAISS as a raw pointer and must be row-major
        # float32. ascontiguousarray is a no-op when that is already the case.
        query = np.ascontiguousarray(query, dtype=np.float32)
        if self.distance_metric == "cosine":
            query = normalize_l2(query)

        params = faiss.SearchParametersHNSW()
        if zmq_port is not None:
            params.zmq_port = zmq_port  # C++ code won't use this if recompute_embeddings is False
        params.efSearch = complexity
        params.beam_size = beam_width

        # For OpenAI embeddings with cosine distance, disable relative distance check
        # This prevents early termination when all scores are in a narrow range
        embedding_model = (self.meta.get("embedding_model") or "").lower()
        is_openai_model = any(
            openai_model in embedding_model for openai_model in ["text-embedding", "openai"]
        )
        params.check_relative_distance = not (
            self.distance_metric == "cosine" and is_openai_model
        )

        # PQ pruning: direct mapping to HNSW's pq_pruning_ratio
        params.pq_pruning_ratio = prune_ratio

        # Map pruning_strategy to HNSW parameters
        if pruning_strategy == "local":
            params.local_prune = True
            params.send_neigh_times_ratio = 0.0
        elif pruning_strategy == "proportional":
            params.local_prune = False
            params.send_neigh_times_ratio = 1.0  # Any value > 1e-6 triggers proportional mode
        else:  # "global"
            params.local_prune = False
            params.send_neigh_times_ratio = 0.0

        # HNSW-specific batch processing parameter
        params.batch_size = batch_size

        num_queries = query.shape[0]
        distances = np.empty((num_queries, top_k), dtype=np.float32)
        labels = np.empty((num_queries, top_k), dtype=np.int64)

        # perf_counter is monotonic, unlike wall-clock time.time().
        started = time.perf_counter()
        self._index.search(
            num_queries,
            faiss.swig_ptr(query),
            top_k,
            faiss.swig_ptr(distances),
            faiss.swig_ptr(labels),
            params,
        )
        search_time = time.perf_counter() - started
        logger.info(f"  Search time in HNSWSearcher.search() backend: {search_time} seconds")

        # tolist() yields plain Python ints in one pass.
        label_rows = labels.tolist()
        id_map = self._id_map
        if id_map:
            id_count = len(id_map)
            # Labels outside the map fall back to their integer form.
            string_labels = [
                [id_map[label] if 0 <= label < id_count else str(label) for label in row]
                for row in label_rows
            ]
        else:
            string_labels = [[str(label) for label in row] for row in label_rows]

        return {"labels": string_labels, "distances": distances}