Files
LEANN/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py
yichuan520030910320 46f6cc100b Initial commit
2025-06-30 09:05:05 +00:00

299 lines
12 KiB
Python

import numpy as np
import os
import json
import struct
from pathlib import Path
from typing import Dict
import contextlib
import threading
import time
import atexit
import socket
import subprocess
import sys
from leann.registry import register_backend
from leann.interface import (
LeannBackendFactoryInterface,
LeannBackendBuilderInterface,
LeannBackendSearcherInterface
)
from . import _diskannpy as diskannpy
METRIC_MAP = {
"mips": diskannpy.Metric.INNER_PRODUCT,
"l2": diskannpy.Metric.L2,
"cosine": diskannpy.Metric.COSINE,
}
@contextlib.contextmanager
def chdir(path):
original_dir = os.getcwd()
os.chdir(path)
try:
yield
finally:
os.chdir(original_dir)
def _write_vectors_to_bin(data: np.ndarray, file_path: str):
num_vectors, dim = data.shape
with open(file_path, 'wb') as f:
f.write(struct.pack('I', num_vectors))
f.write(struct.pack('I', dim))
f.write(data.tobytes())
def _check_port(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
class EmbeddingServerManager:
def __init__(self):
self.server_process = None
self.server_port = None
atexit.register(self.stop_server)
def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
return True
# 检查端口是否已被其他无关进程占用
if _check_port(port):
print(f"WARNING: Port {port} is already in use. Assuming an external server is running and connecting to it.")
return True
print(f"INFO: Starting session-level embedding server as a background process...")
try:
command = [
sys.executable,
"-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server",
"--zmq-port", str(port),
"--model-name", model_name
]
project_root = Path(__file__).parent.parent.parent.parent
print(f"INFO: Running command from project root: {project_root}")
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8'
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
max_wait, wait_interval = 30, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print(f"✅ Embedding server is up and ready for this session.")
log_thread = threading.Thread(target=self._log_monitor, daemon=True)
log_thread.start()
return True
if self.server_process.poll() is not None:
print("❌ ERROR: Server process terminated unexpectedly during startup.")
self._log_monitor()
return False
time.sleep(wait_interval)
print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
self.stop_server()
return False
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server process: {e}")
return False
def _log_monitor(self):
if not self.server_process:
return
try:
if self.server_process.stdout:
for line in iter(self.server_process.stdout.readline, ''):
print(f"[EmbeddingServer LOG]: {line.strip()}")
self.server_process.stdout.close()
if self.server_process.stderr:
for line in iter(self.server_process.stderr.readline, ''):
print(f"[EmbeddingServer ERROR]: {line.strip()}")
self.server_process.stderr.close()
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
if self.server_process and self.server_process.poll() is None:
print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...")
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print("INFO: Server process terminated.")
except subprocess.TimeoutExpired:
print("WARNING: Server process did not terminate gracefully, killing it.")
self.server_process.kill()
self.server_process = None
@register_backend("diskann")
class DiskannBackend(LeannBackendFactoryInterface):
@staticmethod
def builder(**kwargs) -> LeannBackendBuilderInterface:
return DiskannBuilder(**kwargs)
@staticmethod
def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
path = Path(index_path)
meta_path = path.parent / f"{path.name}.meta.json"
if not meta_path.exists():
raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.")
with open(meta_path, 'r') as f:
meta = json.load(f)
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(meta.get("embedding_model"))
dimensions = model.get_sentence_embedding_dimension()
kwargs['dimensions'] = dimensions
except ImportError:
raise ImportError("sentence-transformers is required to infer embedding dimensions. Please install it.")
except Exception as e:
raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")
return DiskannSearcher(index_path, **kwargs)
class DiskannBuilder(LeannBackendBuilderInterface):
def __init__(self, **kwargs):
self.build_params = kwargs
def build(self, data: np.ndarray, index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
index_dir.mkdir(parents=True, exist_ok=True)
if data.dtype != np.float32:
data = data.astype(np.float32)
if not data.flags['C_CONTIGUOUS']:
data = np.ascontiguousarray(data)
data_filename = f"{index_prefix}_data.bin"
_write_vectors_to_bin(data, index_dir / data_filename)
build_kwargs = {**self.build_params, **kwargs}
metric_str = build_kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
complexity = build_kwargs.get("complexity", 64)
graph_degree = build_kwargs.get("graph_degree", 32)
final_index_ram_limit = build_kwargs.get("search_memory_maximum", 4.0)
indexing_ram_budget = build_kwargs.get("build_memory_maximum", 8.0)
num_threads = build_kwargs.get("num_threads", 8)
pq_disk_bytes = build_kwargs.get("pq_disk_bytes", 0)
codebook_prefix = ""
print(f"INFO: Building DiskANN index for {data.shape[0]} vectors with metric {metric_enum}...")
try:
with chdir(index_dir):
diskannpy.build_disk_float_index(
metric_enum,
data_filename,
index_prefix,
complexity,
graph_degree,
final_index_ram_limit,
indexing_ram_budget,
num_threads,
pq_disk_bytes,
codebook_prefix
)
print(f"✅ DiskANN index built successfully at '{index_dir / index_prefix}'")
except Exception as e:
print(f"💥 ERROR: DiskANN index build failed. Exception: {e}")
raise
finally:
temp_data_file = index_dir / data_filename
if temp_data_file.exists():
os.remove(temp_data_file)
class DiskannSearcher(LeannBackendSearcherInterface):
def __init__(self, index_path: str, **kwargs):
path = Path(index_path)
index_dir = path.parent
index_prefix = path.stem
metric_str = kwargs.get("distance_metric", "mips").lower()
metric_enum = METRIC_MAP.get(metric_str)
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{metric_str}'.")
num_threads = kwargs.get("num_threads", 8)
num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0)
dimensions = kwargs.get("dimensions")
if not dimensions:
raise ValueError("Vector dimension not provided to DiskannSearcher.")
try:
full_index_prefix = str(index_dir / index_prefix)
self._index = diskannpy.StaticDiskFloatIndex(
metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", ""
)
self.num_threads = num_threads
self.embedding_server_manager = EmbeddingServerManager()
print("✅ DiskANN index loaded successfully.")
except Exception as e:
print(f"💥 ERROR: Failed to load DiskANN index. Exception: {e}")
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
complexity = kwargs.get("complexity", 100)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)
skip_search_reorder = kwargs.get("skip_search_reorder", False)
recompute_beighbor_embeddings = kwargs.get("recompute_beighbor_embeddings", False)
dedup_node_dis = kwargs.get("dedup_node_dis", False)
prune_ratio = kwargs.get("prune_ratio", 0.0)
batch_recompute = kwargs.get("batch_recompute", False)
global_pruning = kwargs.get("global_pruning", False)
if recompute_beighbor_embeddings:
print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running")
zmq_port = kwargs.get("zmq_port", 5555)
embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
if not self.embedding_server_manager.start_server(zmq_port, embedding_model):
print(f"WARNING: Failed to start embedding server, falling back to PQ computation")
kwargs['recompute_beighbor_embeddings'] = False
if query.dtype != np.float32:
query = query.astype(np.float32)
if query.ndim == 1:
query = np.expand_dims(query, axis=0)
try:
labels, distances = self._index.batch_search(
query,
query.shape[0],
top_k,
complexity,
beam_width,
self.num_threads,
USE_DEFERRED_FETCH,
skip_search_reorder,
recompute_beighbor_embeddings,
dedup_node_dis,
prune_ratio,
batch_recompute,
global_pruning
)
return {"labels": labels, "distances": distances}
except Exception as e:
print(f"💥 ERROR: DiskANN search failed. Exception: {e}")
batch_size = query.shape[0]
return {"labels": np.full((batch_size, top_k), -1, dtype=np.int64),
"distances": np.full((batch_size, top_k), float('inf'), dtype=np.float32)}
def __del__(self):
if hasattr(self, 'embedding_server_manager'):
self.embedding_server_manager.stop_server()