Initial commit

@@ -0,0 +1 @@
from . import hnsw_backend

packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py (new file, 313 lines)
@@ -0,0 +1,313 @@
import numpy as np
import json
from pathlib import Path
from typing import Any, Dict
import threading
import time
import atexit
import socket
import subprocess
import sys

# File: packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py

from leann.registry import register_backend
from leann.interface import (
    LeannBackendFactoryInterface,
    LeannBackendBuilderInterface,
    LeannBackendSearcherInterface,
)

def get_metric_map():
    from . import faiss

    return {
        "mips": faiss.METRIC_INNER_PRODUCT,
        "l2": faiss.METRIC_L2,
        "cosine": faiss.METRIC_INNER_PRODUCT,  # requires L2-normalized vectors
    }


def _check_port(port: int) -> bool:
    """Check whether a port is already in use."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0

class HNSWEmbeddingServerManager:
    """
    HNSW-specific embedding server manager that owns the lifecycle of the
    embedding server process. Mirrors the DiskANN EmbeddingServerManager
    architecture.
    """

    def __init__(self):
        self.server_process = None
        self.server_port = None
        atexit.register(self.stop_server)

    def start_server(
        self,
        port=5556,
        model_name="sentence-transformers/all-mpnet-base-v2",
        passages_file=None,
    ):
        """
        Start the HNSW embedding server process.

        Args:
            port: ZMQ port for the server
            model_name: Name of the embedding model to use
            passages_file: Optional path to a passages JSON file
        """
        if self.server_process and self.server_process.poll() is None:
            print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})")
            return True

        # Check if the port is already in use
        if _check_port(port):
            print(f"WARNING: Port {port} is already in use. Assuming an external HNSW server is running and connecting to it.")
            return True

        print("INFO: Starting session-level HNSW embedding server as a background process...")

        try:
            # The module path must be the importable package name; a
            # hyphenated directory path such as "packages.leann-backend-hnsw..."
            # is not a valid Python module path.
            command = [
                sys.executable,
                "-m", "leann_backend_hnsw.hnsw_embedding_server",
                "--zmq-port", str(port),
                "--model-name", model_name,
            ]

            # Add the passages file if provided
            if passages_file:
                command.extend(["--passages-file", str(passages_file)])

            project_root = Path(__file__).parent.parent.parent.parent
            print(f"INFO: Running HNSW command from project root: {project_root}")

            self.server_process = subprocess.Popen(
                command,
                cwd=project_root,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
            )
            self.server_port = port
            print(f"INFO: HNSW server process started with PID: {self.server_process.pid}")

            max_wait, wait_interval = 30, 0.5
            for _ in range(int(max_wait / wait_interval)):
                if _check_port(port):
                    print("✅ HNSW embedding server is up and ready for this session.")
                    log_thread = threading.Thread(target=self._log_monitor, daemon=True)
                    log_thread.start()
                    return True
                if self.server_process.poll() is not None:
                    print("❌ ERROR: HNSW server process terminated unexpectedly during startup.")
                    self._log_monitor()
                    return False
                time.sleep(wait_interval)

            print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.")
            self.stop_server()
            return False

        except Exception as e:
            print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}")
            return False

    def _log_monitor(self):
        """Monitor server logs."""
        if not self.server_process:
            return
        try:
            if self.server_process.stdout:
                for line in iter(self.server_process.stdout.readline, ""):
                    print(f"[HNSWEmbeddingServer LOG]: {line.strip()}")
                self.server_process.stdout.close()
            if self.server_process.stderr:
                for line in iter(self.server_process.stderr.readline, ""):
                    print(f"[HNSWEmbeddingServer ERROR]: {line.strip()}")
                self.server_process.stderr.close()
        except Exception as e:
            print(f"HNSW Log monitor error: {e}")

    def stop_server(self):
        """Stop the HNSW embedding server process."""
        if self.server_process and self.server_process.poll() is None:
            print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...")
            self.server_process.terminate()
            try:
                self.server_process.wait(timeout=5)
                print("INFO: HNSW server process terminated.")
            except subprocess.TimeoutExpired:
                print("WARNING: HNSW server process did not terminate gracefully, killing it.")
                self.server_process.kill()
            self.server_process = None

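# Illustrative lifecycle note: a searcher owns one HNSWEmbeddingServerManager
# per session; start_server() is invoked lazily on the first recompute-enabled
# search, reuses a live child process (or an externally started server on the
# same port), and the atexit hook ensures stop_server() runs at interpreter
# shutdown.
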
@register_backend("hnsw")
class HNSWBackend(LeannBackendFactoryInterface):
    @staticmethod
    def builder(**kwargs) -> LeannBackendBuilderInterface:
        return HNSWBuilder(**kwargs)

    @staticmethod
    def searcher(index_path: str, **kwargs) -> LeannBackendSearcherInterface:
        path = Path(index_path)
        meta_path = path.parent / f"{path.stem}.hnsw.meta.json"
        if not meta_path.exists():
            raise FileNotFoundError(
                f"Leann metadata file not found at {meta_path}. "
                "Cannot infer vector dimension for searcher."
            )

        with open(meta_path, "r") as f:
            meta = json.load(f)

        try:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer(meta.get("embedding_model"))
            kwargs["dimensions"] = model.get_sentence_embedding_dimension()
        except ImportError:
            raise ImportError(
                "sentence-transformers is required to infer embedding dimensions. Please install it."
            )
        except Exception as e:
            raise RuntimeError(f"Could not load SentenceTransformer model to get dimension: {e}")

        return HNSWSearcher(index_path, **kwargs)

class HNSWBuilder(LeannBackendBuilderInterface):
    def __init__(self, **kwargs):
        self.build_params = kwargs

    def build(self, data: np.ndarray, index_path: str, **kwargs):
        """Build an HNSW index using FAISS."""
        from . import faiss

        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem

        index_dir.mkdir(parents=True, exist_ok=True)

        if data.dtype != np.float32:
            data = data.astype(np.float32)
        if not data.flags["C_CONTIGUOUS"]:
            data = np.ascontiguousarray(data)

        build_kwargs = {**self.build_params, **kwargs}
        metric_str = build_kwargs.get("distance_metric", "mips").lower()
        metric_enum = get_metric_map().get(metric_str)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

        # HNSW parameters
        M = build_kwargs.get("M", 32)  # max connections per layer
        efConstruction = build_kwargs.get("efConstruction", 200)  # candidate list size during construction
        dim = data.shape[1]

        print(f"INFO: Building HNSW index for {data.shape[0]} vectors with metric {metric_enum}...")

        try:
            # Create the HNSW index; the metric enum covers both the inner
            # product and L2 cases, so a single constructor call suffices.
            index = faiss.IndexHNSWFlat(dim, M, metric_enum)

            # Set construction parameters
            index.hnsw.efConstruction = efConstruction

            # Normalize vectors if using cosine similarity
            if metric_str == "cosine":
                faiss.normalize_L2(data)

            # Add vectors to the index; the pythonic FAISS wrapper takes the
            # array directly rather than a count plus a SWIG pointer.
            index.add(data)

            # Save the index
            index_file = index_dir / f"{index_prefix}.index"
            faiss.write_index(index, str(index_file))

            print(f"✅ HNSW index built successfully at '{index_file}'")

        except Exception as e:
            print(f"💥 ERROR: HNSW index build failed. Exception: {e}")
            raise

class HNSWSearcher(LeannBackendSearcherInterface):
    def __init__(self, index_path: str, **kwargs):
        from . import faiss

        path = Path(index_path)
        index_dir = path.parent
        index_prefix = path.stem

        metric_str = kwargs.get("distance_metric", "mips").lower()
        metric_enum = get_metric_map().get(metric_str)
        if metric_enum is None:
            raise ValueError(f"Unsupported distance_metric '{metric_str}'.")

        dimensions = kwargs.get("dimensions")
        if not dimensions:
            raise ValueError("Vector dimension not provided to HNSWSearcher.")

        try:
            # Load the FAISS HNSW index
            index_file = index_dir / f"{index_prefix}.index"
            if not index_file.exists():
                raise FileNotFoundError(f"HNSW index file not found at {index_file}")

            self._index = faiss.read_index(str(index_file))
            self.metric_str = metric_str
            self.embedding_server_manager = HNSWEmbeddingServerManager()
            print("✅ HNSW index loaded successfully.")

        except Exception as e:
            print(f"💥 ERROR: Failed to load HNSW index. Exception: {e}")
            raise

    def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]:
        """Search the HNSW index, with optional neighbor-embedding recompute."""
        from . import faiss  # needed in this scope for normalize_L2

        ef = kwargs.get("ef", 200)  # candidate list size during search

        # Recompute parameters
        recompute_neighbor_embeddings = kwargs.get("recompute_neighbor_embeddings", False)
        zmq_port = kwargs.get("zmq_port", 5556)
        embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2")
        passages_file = kwargs.get("passages_file", None)

        if recompute_neighbor_embeddings:
            print("INFO: HNSW ZMQ mode enabled - ensuring embedding server is running")
            if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file):
                print("WARNING: Failed to start HNSW embedding server, falling back to standard search")
                # Clear the local flag (mutating kwargs would have no effect here)
                recompute_neighbor_embeddings = False

        if query.dtype != np.float32:
            query = query.astype(np.float32)
        if query.ndim == 1:
            query = np.expand_dims(query, axis=0)

        # Normalize the query if using cosine similarity
        if self.metric_str == "cosine":
            faiss.normalize_L2(query)

        try:
            # Set the search parameter
            self._index.hnsw.efSearch = ef

            if recompute_neighbor_embeddings:
                # Recompute would require custom HNSW traversal logic; until
                # that lands, fall back to the standard search.
                print("WARNING: Recompute functionality for HNSW not yet implemented, using standard search")

            # Standard FAISS search
            distances, labels = self._index.search(query, top_k)

            return {"labels": labels, "distances": distances}

        except Exception as e:
            print(f"💥 ERROR: HNSW search failed. Exception: {e}")
            batch_size = query.shape[0]
            return {
                "labels": np.full((batch_size, top_k), -1, dtype=np.int64),
                "distances": np.full((batch_size, top_k), float("inf"), dtype=np.float32),
            }

    def __del__(self):
        if hasattr(self, "embedding_server_manager"):
            self.embedding_server_manager.stop_server()
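A minimal end-to-end sketch of the builder and searcher above. This is illustrative only: the index path, dimensions, and random data are hypothetical, and HNSWSearcher is constructed directly with an explicit `dimensions` kwarg to sidestep the metadata-file lookup that HNSWBackend.searcher() performs:

import numpy as np
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher

data = np.random.rand(1000, 768).astype(np.float32)  # toy corpus embeddings

builder = HNSWBuilder(distance_metric="mips", M=32, efConstruction=200)
builder.build(data, "indexes/demo.index")

searcher = HNSWSearcher("indexes/demo.index", distance_metric="mips", dimensions=768)
query = np.random.rand(768).astype(np.float32)
results = searcher.search(query, top_k=10)
print(results["labels"][0], results["distances"][0])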
@@ -0,0 +1,583 @@
#!/usr/bin/env python3
"""
HNSW-specific embedding server with the config.py dependencies removed.
Based on the DiskANN embedding server architecture.
"""

import pickle
import argparse
import threading
import time
import os
import json
from contextlib import contextmanager
from typing import Any, Dict, Optional, Union

import numpy as np
import msgpack
import zmq
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

RED = "\033[91m"
RESET = "\033[0m"


def is_similarity_metric():
    """
    Whether the metric type is similarity-based (like inner product).
    In FAISS terms: 0 = L2 (a distance metric), 1 = inner product (a
    similarity metric). This server currently assumes inner product.
    """
    return True  # 1 is METRIC_INNER_PRODUCT in FAISS


# Pooling function for E5-style models: zero out padding tokens, then average.
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

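# Shape sketch (illustrative): for a batch of B sequences of length T with
# hidden size H, last_hidden_states is (B, T, H) and attention_mask is (B, T);
# masked positions are zeroed, summed over T, and divided by the per-sequence
# token count, yielding a (B, H) matrix of mean-pooled embeddings.
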
class SimplePassageLoader:
    """
    Simple passage loader that replaces the config.py dependencies.
    """

    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
        self.passages_data = passages_data or {}

    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
        """Get a passage by ID."""
        str_id = str(passage_id)
        if str_id in self.passages_data:
            return {"text": self.passages_data[str_id]}
        else:
            # Return empty text for missing passages
            return {"text": ""}

    def __len__(self) -> int:
        return len(self.passages_data)

def load_passages_from_file(passages_file: str) -> SimplePassageLoader:
    """
    Load passages from a JSON file.
    Expected format: {"passage_id": "passage_text", ...}
    """
    if not os.path.exists(passages_file):
        print(f"Warning: Passages file {passages_file} not found. Using empty loader.")
        return SimplePassageLoader()

    try:
        with open(passages_file, "r", encoding="utf-8") as f:
            passages_data = json.load(f)
        print(f"Loaded {len(passages_data)} passages from {passages_file}")
        return SimplePassageLoader(passages_data)
    except Exception as e:
        print(f"Error loading passages from {passages_file}: {e}")
        return SimplePassageLoader()

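# Illustrative passages file accepted by load_passages_from_file; the IDs and
# texts below are hypothetical examples, not a real dataset:
#
#   {
#       "0": "LEANN is a low-storage vector index...",
#       "1": "HNSW builds a layered proximity graph..."
#   }
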
def create_hnsw_embedding_server(
    passages_file: Optional[str] = None,
    passages_data: Optional[Dict[str, str]] = None,
    embeddings_file: Optional[str] = None,
    use_fp16: bool = True,
    use_int8: bool = False,
    use_cuda_graphs: bool = False,
    zmq_port: int = 5555,
    max_batch_size: int = 128,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    custom_max_length_param: Optional[int] = None,
):
    """
    Create and start a ZMQ-based embedding server for the HNSW backend.

    Args:
        passages_file: Path to a JSON file mapping passage ID -> text
        passages_data: Direct passage data dict (alternative to passages_file)
        embeddings_file: Path to a pre-computed embeddings file (optional)
        use_fp16: Whether to use FP16 precision
        use_int8: Whether to use INT8 quantization
        use_cuda_graphs: Whether to use CUDA graphs
        zmq_port: ZMQ port to bind to
        max_batch_size: Maximum batch size for processing
        model_name: Transformer model name
        custom_max_length_param: Custom max sequence length
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Device setup
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    cuda_available = torch.cuda.is_available()

    print(f"MPS available: {mps_available}")
    print(f"CUDA available: {cuda_available}")

    if cuda_available:
        device = torch.device("cuda")
        print("Using CUDA device")
    elif mps_available:
        device = torch.device("mps")
        print("Using MPS device (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("Using CPU device (no GPU acceleration available)")

    # Load the model onto the selected device
    print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
    model = AutoModel.from_pretrained(model_name).to(device).eval()

    # Check port availability
    import socket

    def check_port(port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        print(f"{RED}Port {zmq_port} is already in use{RESET}")
        return

    # Apply model optimizations (similar to the DiskANN version)
    if use_fp16 and (cuda_available or mps_available):
        model = model.half()
        model = torch.compile(model)
        print(f"Using FP16 precision with model: {model_name}")
    elif use_int8:
        print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
        from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

        quantize_(model, Int8DynamicActivationInt8WeightConfig())
        model = torch.compile(model)
        model.eval()
        print("- Model successfully quantized and compiled")

    # Load passages
    if passages_data:
        passages = SimplePassageLoader(passages_data)
        print(f"Using provided passages data: {len(passages)} passages")
    elif passages_file:
        passages = load_passages_from_file(passages_file)
    else:
        passages = SimplePassageLoader()
        print("No passages provided, using empty loader")

    # Load pre-computed embeddings if provided
    _embeddings = None
    if embeddings_file and os.path.exists(embeddings_file):
        try:
            with open(embeddings_file, "rb") as f:
                _embeddings = pickle.load(f)
            print(f"Loaded embeddings from {embeddings_file}")
        except Exception as e:
            print(f"Error loading embeddings: {e}")

    class DeviceTimer:
        """Device event-based timer for accurate timing."""

        def __init__(self, name="", device=device):
            self.name = name
            self.device = device
            self.start_time = 0
            self.end_time = 0

            if cuda_available:
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
            else:
                self.start_event = None
                self.end_event = None

        @contextmanager
        def timing(self):
            self.start()
            yield
            self.end()

        def start(self):
            if cuda_available:
                torch.cuda.synchronize()
                self.start_event.record()
            else:
                if self.device.type == "mps":
                    torch.mps.synchronize()
                self.start_time = time.time()

        def end(self):
            if cuda_available:
                self.end_event.record()
                torch.cuda.synchronize()
            else:
                if self.device.type == "mps":
                    torch.mps.synchronize()
                self.end_time = time.time()

        def elapsed_time(self):
            if cuda_available:
                return self.start_event.elapsed_time(self.end_event) / 1000.0
            else:
                return self.end_time - self.start_time

        def print_elapsed(self):
            return  # Disabled for now

    def process_batch(texts_batch, ids_batch, missing_ids):
        """Process a batch of texts and return their embeddings."""
        _is_e5_model = "e5" in model_name.lower()
        batch_size = len(texts_batch)

        # E5 models expect a "passage: " prefix on document text
        if _is_e5_model:
            processed_texts_batch = [f"passage: {text}" for text in texts_batch]
        else:
            processed_texts_batch = texts_batch

        # Set the max sequence length
        if _is_e5_model:
            current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
        else:
            current_max_length = custom_max_length_param if custom_max_length_param is not None else 256

        tokenize_timer = DeviceTimer("tokenization (batch)", device)
        to_device_timer = DeviceTimer("transfer to device (batch)", device)
        embed_timer = DeviceTimer("embedding (batch)", device)
        pool_timer = DeviceTimer("pooling (batch)", device)
        norm_timer = DeviceTimer("normalization (batch)", device)

        with tokenize_timer.timing():
            encoded_batch = tokenizer(
                processed_texts_batch,
                padding="max_length",
                truncation=True,
                max_length=current_max_length,
                return_tensors="pt",
                return_token_type_ids=False,
            )

        seq_length = encoded_batch["input_ids"].size(1)

        with to_device_timer.timing():
            enc = {k: v.to(device) for k, v in encoded_batch.items()}

        with torch.no_grad():
            with embed_timer.timing():
                out = model(enc["input_ids"], enc["attention_mask"])

            with pool_timer.timing():
                if not hasattr(out, "last_hidden_state"):
                    if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                        pooled_embeddings = out
                    else:
                        print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}")
                        hidden_dim = getattr(model.config, "hidden_size", 384 if _is_e5_model else 768)
                        # Fall back to zero embeddings; these must be float,
                        # not the integer dtype of input_ids.
                        pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=torch.float32)
                elif _is_e5_model:
                    pooled_embeddings = e5_average_pool(out.last_hidden_state, enc["attention_mask"])
                else:
                    # Mean pooling over non-padding tokens
                    hidden_states = out.last_hidden_state
                    mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                    pooled_embeddings = sum_embeddings / sum_mask

            final_embeddings = pooled_embeddings
            if _is_e5_model:
                with norm_timer.timing():
                    final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)

        if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
            print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
                  f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
            dim_size = final_embeddings.shape[-1]
            error_output = torch.zeros((batch_size, dim_size), device="cpu", dtype=torch.float32).numpy()
            print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
            return error_output

        return final_embeddings.cpu().numpy()

    def client_warmup(zmq_port):
        """Perform a client-side warmup."""
        time.sleep(2)
        print(f"Performing client-side warmup with model {model_name}...")
        sample_ids = ["1", "2", "3", "4", "5"]

        try:
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            socket.connect(f"tcp://localhost:{zmq_port}")
            socket.setsockopt(zmq.RCVTIMEO, 30000)
            socket.setsockopt(zmq.SNDTIMEO, 30000)

            try:
                ids_to_send = [int(x) for x in sample_ids]
            except ValueError:
                ids_to_send = []

            if not ids_to_send:
                print("Skipping warmup send.")
                return

            request_payload = [ids_to_send]
            request_bytes = msgpack.packb(request_payload)

            for i in range(3):
                print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
                socket.send(request_bytes)
                response_bytes = socket.recv()

                response_payload = msgpack.unpackb(response_bytes)
                dimensions = response_payload[0]
                embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
                print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
                time.sleep(0.1)

            print("Client-side MessagePack ZMQ warmup complete")
            socket.close()
            context.term()
        except Exception as e:
            print(f"Error during MessagePack ZMQ warmup: {e}")

    def zmq_server_thread():
        """ZMQ server thread."""
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{zmq_port}")
        print(f"HNSW ZMQ server listening on port {zmq_port}")

        socket.setsockopt(zmq.RCVTIMEO, 300000)
        socket.setsockopt(zmq.SNDTIMEO, 300000)

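        # Wire protocol (as implemented below, all MessagePack-encoded):
        #   embedding request: [[id, id, ...]]
        #       -> response: [shape, flattened float32 values]
        #   distance request:  [[id, id, ...], [query float, ...]]
        #       -> response: [[distance, distance, ...]]
        # Malformed requests receive an empty [[], []] response.
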
        while True:
            try:
                message_bytes = socket.recv()
                print(f"Received ZMQ request of size {len(message_bytes)} bytes")

                e2e_start = time.time()
                lookup_timer = DeviceTimer("text lookup", device)

                try:
                    request_payload = msgpack.unpackb(message_bytes)

                    # Handle distance-calculation requests
                    if (isinstance(request_payload, list) and len(request_payload) == 2
                            and isinstance(request_payload[0], list) and isinstance(request_payload[1], list)):
                        node_ids = request_payload[0]
                        query_vector = np.array(request_payload[1], dtype=np.float32)

                        print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")

                        # Look up the texts for the requested node IDs
                        texts = []
                        missing_ids = []
                        with lookup_timer.timing():
                            for nid in node_ids:
                                txtinfo = passages[nid]
                                if txtinfo is None or txtinfo["text"] == "":
                                    print(f"Warning: Passage with ID {nid} not found")
                                    missing_ids.append(nid)
                                    txt = ""
                                else:
                                    txt = txtinfo["text"]
                                texts.append(txt)
                        lookup_timer.print_elapsed()

                        # Embed the texts, chunking if the batch is too large
                        all_node_embeddings = []
                        total_size = len(texts)

                        if total_size > max_batch_size:
                            for i in range(0, total_size, max_batch_size):
                                end_idx = min(i + max_batch_size, total_size)
                                chunk_texts = texts[i:end_idx]
                                chunk_ids = node_ids[i:end_idx]

                                embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
                                all_node_embeddings.append(embeddings_chunk)

                                if cuda_available:
                                    torch.cuda.empty_cache()
                                elif device.type == "mps":
                                    torch.mps.empty_cache()

                            node_embeddings = np.vstack(all_node_embeddings)
                        else:
                            node_embeddings = process_batch(texts, node_ids, missing_ids)

                        # Calculate distances
                        query_tensor = torch.tensor(query_vector, device=device).float()
                        node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()

                        calc_timer = DeviceTimer("distance calculation", device)
                        with calc_timer.timing():
                            with torch.no_grad():
                                if is_similarity_metric():
                                    # Negated inner product, so that smaller means closer
                                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
                                    query_np = query_tensor.cpu().numpy()
                                    distances = -np.dot(node_embeddings_np, query_np)
                                else:
                                    # Squared L2 distance
                                    node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
                                    query_np = query_tensor.cpu().numpy().astype(np.float32)
                                    distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
                        calc_timer.print_elapsed()

                        try:
                            response_payload = distances.flatten().tolist()
                            response_bytes = msgpack.packb([response_payload], use_single_float=True)
                            print(f"Sending distance response with {len(distances)} distances")
                        except Exception as pack_error:
                            print(f"Error packing MessagePack distance response: {pack_error}")
                            response_bytes = msgpack.packb([[]])

                        socket.send(response_bytes)

                        if device.type == "cuda":
                            torch.cuda.synchronize()
                        elif device.type == "mps":
                            torch.mps.synchronize()
                        e2e_end = time.time()
                        print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
                        continue

                    # Standard embedding request
                    if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
                        print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
                        socket.send(msgpack.packb([[], []]))
                        continue

                    node_ids = request_payload[0]
                    print(f"Request for {len(node_ids)} node embeddings")

                except Exception as unpack_error:
                    print(f"Error unpacking MessagePack request: {unpack_error}")
                    socket.send(msgpack.packb([[], []]))
                    continue

                # Look up texts by node ID
                texts = []
                missing_ids = []
                with lookup_timer.timing():
                    for nid in node_ids:
                        txtinfo = passages[nid]
                        if txtinfo is None or txtinfo["text"] == "":
                            print(f"Warning: Passage with ID {nid} not found")
                            missing_ids.append(nid)
                            txt = ""
                        else:
                            txt = txtinfo["text"]
                        texts.append(txt)
                lookup_timer.print_elapsed()

                if missing_ids:
                    print(f"Missing passages for IDs: {missing_ids}")

                # Process in chunks
                total_size = len(texts)
                print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")

                all_embeddings = []

                if total_size > max_batch_size:
                    print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
                    for i in range(0, total_size, max_batch_size):
                        end_idx = min(i + max_batch_size, total_size)
                        print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")

                        chunk_texts = texts[i:end_idx]
                        chunk_ids = node_ids[i:end_idx]

                        embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
                        all_embeddings.append(embeddings_chunk)

                        if cuda_available:
                            torch.cuda.empty_cache()
                        elif device.type == "mps":
                            torch.mps.empty_cache()

                    hidden = np.vstack(all_embeddings)
                    print(f"Combined embeddings shape: {hidden.shape}")
                else:
                    hidden = process_batch(texts, node_ids, missing_ids)

                # Serialization and response
                ser_start = time.time()

                print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
                if np.isnan(hidden).any() or np.isinf(hidden).any():
                    print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
                          f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
                    # Fail loudly; the outer handler reports the error and
                    # replies with an empty payload.
                    assert False

                try:
                    hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
                    response_payload = [
                        list(hidden_contiguous_f32.shape),
                        hidden_contiguous_f32.flatten().tolist(),
                    ]
                    response_bytes = msgpack.packb(response_payload, use_single_float=True)
                except Exception as pack_error:
                    print(f"Error packing MessagePack response: {pack_error}")
                    response_bytes = msgpack.packb([[], []])

                socket.send(response_bytes)
                ser_end = time.time()

                print(f"Serialize time: {ser_end - ser_start:.6f} seconds")

                if device.type == "cuda":
                    torch.cuda.synchronize()
                elif device.type == "mps":
                    torch.mps.synchronize()
                e2e_end = time.time()
                print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

            except zmq.Again:
                print("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                print(f"Error in ZMQ server loop: {e}")
                import traceback
                traceback.print_exc()
                try:
                    socket.send(msgpack.packb([[], []]))
                except Exception:
                    pass

    # Start the warmup and server threads
    if len(passages) > 0:
        warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
        warmup_thread.daemon = True
        warmup_thread.start()

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    print(f"Started HNSW ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("HNSW Server shutting down...")
        return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HNSW Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
    parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings")
    parser.add_argument("--use-fp16", action="store_true", default=False)
    parser.add_argument("--use-int8", action="store_true", default=False)
    parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
    parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
    parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
                        help="Embedding model name")
    parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")

    args = parser.parse_args()

    # Create and start the HNSW embedding server
    create_hnsw_embedding_server(
        passages_file=args.passages_file,
        embeddings_file=args.embeddings_file,
        use_fp16=args.use_fp16,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        zmq_port=args.zmq_port,
        max_batch_size=args.max_batch_size,
        model_name=args.model_name,
        custom_max_length_param=args.custom_max_length,
    )
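A minimal client sketch for the MessagePack protocol implemented above, assuming a server is already listening on port 5556 with a passages file loaded; the node IDs and the zero query vector are hypothetical placeholders:

import msgpack
import zmq

context = zmq.Context()
sock = context.socket(zmq.REQ)
sock.connect("tcp://localhost:5556")

# Embedding request: [[ids]] -> [shape, flattened float32 values]
sock.send(msgpack.packb([[0, 1, 2]]))
shape, flat = msgpack.unpackb(sock.recv())
print(f"received embeddings with shape {shape}")

# Distance request: [[ids], [query vector]] -> [[distances]]
dim = shape[1] if len(shape) == 2 else 768
sock.send(msgpack.packb([[0, 1, 2], [0.0] * dim]))
(distances,) = msgpack.unpackb(sock.recv())
print(f"received {len(distances)} distances")

sock.close()
context.term()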