import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Any, Dict
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
# File: packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
def get_metric_map():
from . import faiss
return {
"mips": faiss.METRIC_INNER_PRODUCT,
"l2": faiss.METRIC_L2,
"cosine": faiss.METRIC_INNER_PRODUCT, # Will need normalization
}
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class HNSWEmbeddingServerManager:
"""
HNSW-specific embedding server manager that handles the lifecycle of the embedding server process.
Mirrors the DiskANN EmbeddingServerManager architecture.
"""
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None):
"""
Start the HNSW embedding server process.
Args:
port: ZMQ port for the server
model_name: Name of the embedding model to use
passages_file: Optional path to passages JSON file
"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
return True
# Check if port is already in use
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
return True
print(f"INFO: Starting session-level HNSW embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "packages.leann-backend-hnsw.src.leann_backend_hnsw.hnsw_embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
]
# Add passages file if provided
if passages_file:
command.extend(["--passages-file", str(passages_file)])
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running HNSW command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ HNSW embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
return False
def _log_monitor(self):
"""Monitor server logs"""
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"HNSW Log monitor error: {e}")
def stop_server(self):
"""Stop the HNSW embedding server process"""
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: HNSW server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: HNSW server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
@staticmethod
def builder(**kwargs) -> LeannBackendBuilderInterface:
return HNSWBuilder(**kwargs)
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.stem}.hnsw.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
with open(meta_path, 'r') as f:
meta = json.load(f)
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(meta.get("embedding_model"))
dimensions = model.get_sentence_embedding_dimension()
kwargs['dimensions'] = dimensions
except ImportError:
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
except Exception as e:
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
return HNSWSearcher(index_path, **kwargs)
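# The @register_backend("hnsw") decorator registers this factory under the name
# "hnsw" so the leann core can discover it; builder() and searcher() are the two
# entry points used to obtain index builders and searchers for this backend.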
class HNSWBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs):
"""Build HNSW index using FAISS"""
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
# HNSW parameters
M = build_kwargs.get("M", 32) # Max connections per layer
efConstruction = build_kwargs.get("efConstruction", 200) # Size of the dynamic candidate list for construction
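        # General HNSW guidance (not project-specific): larger M and efConstruction
        # improve recall at the cost of memory footprint and build time.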
dim = data.shape[1]
print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
            # Create the HNSW index; IndexHNSWFlat handles both inner-product and L2 metrics
            index = faiss.IndexHNSWFlat(dim, M, metric_enum)
# Set construction parameters
index.hnsw.efConstruction = efConstruction
# Normalize vectors if using cosine similarity
if metric_str == "cosine":
faiss.normalize_L2(data)
            # Add vectors through the numpy wrapper (consistent with the search() calls below)
            index.add(data)
# Save index
index_file = index_dir / f"{index_prefix}.index"
faiss.write_index(index, str(index_file))
print(f"✅ HNSW index built successfully at '{index_file}'")
except Exception as e:
print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
raise
class HNSWSearcher(LeannBackendSearcherInterface):
def __init__(self, index_path: str, **kwargs):
from . import faiss
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
metric_str = kwargs.get("distance_metric", "mips").lower()
metric_enum = get_metric_map().get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
dimensions = kwargs.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to HNSWSearcher.")
try:
# Load FAISS HNSW index
index_file = index_dir / f"{index_prefix}.index"
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found at {index_file}")
self._index = faiss.read_index(str(index_file))
self.metric_str = metric_str
self.embedding_server_manager = HNSWEmbeddingServerManager()
print("✅ HNSW index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load HNSW index. Exception: {e}")
raise
    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
"""Search using HNSW index with optional recompute functionality"""
ef = kwargs.get("ef", 200) # Size of the dynamic candidate list for search
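        # Rule of thumb: efSearch should be at least top_k; larger values trade latency for recall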
# Recompute parameters
recompute_neighbor_embeddings = kwargs.get("recompute_neighbor_embeddings", False)
zmq_port = kwargs.get("zmq_port", 5556)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
passages_file = kwargs.get("passages_file", None)
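        # "Recompute" means re-embedding candidate passages on the fly via the ZMQ
        # embedding server instead of relying only on stored vectors; the HNSW path
        # currently falls back to standard search (see below).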
if recompute_neighbor_embeddings:
print(f"INFO: HNSW ZMQ mode enabled - ensuring embedding server is running")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
print(f"WARNING: Failed to start HNSW embedding server, falling back to standard search")
kwargs['recompute_neighbor_embeddings'] = False
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
        # Normalize the query in-place so inner-product search matches cosine
        # similarity (mirrors the normalization applied at build time)
        if self.metric_str == "cosine":
            from . import faiss
            faiss.normalize_L2(query)
try:
# Set search parameter
self._index.hnsw.efSearch = ef
if recompute_neighbor_embeddings:
# Use custom search with recompute
# This would require implementing custom HNSW search logic
# For now, we'll fall back to standard search
print("WARNING: Recompute functionality for HNSW not yet implemented, using standard search")
distances, labels = self._index.search(query, top_k)
else:
# Standard FAISS search
distances, labels = self._index.search(query, top_k)
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: HNSW search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()
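
if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; assumes the bundled faiss
    # extension is importable and exposes the standard numpy-friendly FAISS API).
    # Builds a tiny index in a temporary directory and runs one query against it
    # without starting the embedding server.
    import tempfile

    rng = np.random.default_rng(0)
    vectors = rng.random((128, 64), dtype=np.float32)

    with tempfile.TemporaryDirectory() as tmp_dir:
        demo_index_path = str(Path(tmp_dir) / "demo.index")
        HNSWBuilder(M=16, efConstruction=100).build(vectors, demo_index_path)

        searcher = HNSWSearcher(demo_index_path, dimensions=64, distance_metric="mips")
        results = searcher.search(vectors[0], top_k=5, ef=64)
        print("labels:", results["labels"])
        print("distances:", results["distances"])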