#!/usr/bin/env python3
"""
HNSW-specific embedding server with removed config.py dependencies
Based on DiskANN embedding server architecture
"""
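# Example launch (illustrative only; the flags are defined in the argparse section at
# the end of this file, and the script name and paths below are placeholders):
#   python hnsw_embedding_server.py --zmq-port 5555 \
#       --passages-file /path/to/index.meta.json \
#       --model-name sentence-transformers/all-mpnet-base-v2 --use-fp16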
import pickle
import argparse
import threading
import time
from transformers import AutoTokenizer, AutoModel
import os
from contextlib import contextmanager
import zmq
import numpy as np
import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys

RED = "\033[91m"
RESET = "\033[0m"


def is_similarity_metric():
    """
    Check if the metric type is similarity-based (like inner product).
    0 = L2 (distance metric), 1 = Inner Product (similarity metric)
    """
    return True  # 1 is METRIC_INNER_PRODUCT in FAISS


# Function for E5-style average pooling
import torch
from torch import Tensor
import torch.nn.functional as F

def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

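# Shape note (descriptive only): e5_average_pool expects last_hidden_states of shape
# (batch, seq_len, hidden) and attention_mask of shape (batch, seq_len); it returns
# the masked mean over the sequence dimension, i.e. a (batch, hidden) tensor.
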
class SimplePassageLoader:
    """
    Simple passage loader that replaces config.py dependencies
    """

    def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
        self.passages_data = passages_data or {}
        self._meta_path = ""

    def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
        """Get passage by ID"""
        str_id = str(passage_id)
        if str_id in self.passages_data:
            return {"text": self.passages_data[str_id]}
        else:
            # Return empty text for missing passages
            return {"text": ""}

    def __len__(self) -> int:
        return len(self.passages_data)

    def keys(self):
        return self.passages_data.keys()

def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
    """
    Load passages using metadata file with PassageManager for lazy loading
    """
    # Load metadata to get passage sources
    with open(meta_file, "r") as f:
        meta = json.load(f)

    # Import PassageManager dynamically to avoid circular imports
    # Find the leann package directory relative to this file
    current_dir = Path(__file__).parent
    leann_core_path = current_dir.parent.parent / "leann-core" / "src"
    sys.path.insert(0, str(leann_core_path))

    try:
        from leann.api import PassageManager

        passage_manager = PassageManager(meta["passage_sources"])
    finally:
        sys.path.pop(0)

    # Load label map
    passages_dir = Path(meta_file).parent
    label_map_file = passages_dir / "leann.labels.map"

    if label_map_file.exists():
        with open(label_map_file, "rb") as f:
            label_map = pickle.load(f)
        print(f"Loaded label map with {len(label_map)} entries")
    else:
        raise FileNotFoundError(f"Label map file not found: {label_map_file}")

    print(f"Initialized lazy passage loading for {len(label_map)} passages")

    class LazyPassageLoader(SimplePassageLoader):
        def __init__(self, passage_manager, label_map):
            self.passage_manager = passage_manager
            self.label_map = label_map
            # Initialize parent with empty data
            super().__init__({})

        def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
            """Get passage by ID with lazy loading"""
            try:
                int_id = int(passage_id)
                if int_id in self.label_map:
                    string_id = self.label_map[int_id]
                    passage_data = self.passage_manager.get_passage(string_id)
                    if passage_data and passage_data.get("text"):
                        return {"text": passage_data["text"]}
                    else:
                        print(f"DEBUG: Empty text for ID {int_id} -> {string_id}")
                        return {"text": ""}
                else:
                    print(f"DEBUG: ID {int_id} not found in label_map")
                    return {"text": ""}
            except Exception as e:
                print(f"DEBUG: Exception getting passage {passage_id}: {e}")
                return {"text": ""}

        def __len__(self) -> int:
            return len(self.label_map)

        def keys(self):
            return self.label_map.keys()

    return LazyPassageLoader(passage_manager, label_map)

def create_hnsw_embedding_server(
    passages_file: Optional[str] = None,
    passages_data: Optional[Dict[str, str]] = None,
    embeddings_file: Optional[str] = None,
    use_fp16: bool = True,
    use_int8: bool = False,
    use_cuda_graphs: bool = False,
    zmq_port: int = 5555,
    max_batch_size: int = 128,
    model_name: str = "sentence-transformers/all-mpnet-base-v2",
    custom_max_length_param: Optional[int] = None,
    distance_metric: str = "mips",
    use_mlx: bool = False,
    enable_warmup: bool = False,
):
    """
    Create and start a ZMQ-based embedding server for HNSW backend.

    Args:
        passages_file: Path to JSON file containing passage ID -> text mapping
        passages_data: Direct passage data dict (alternative to passages_file)
        embeddings_file: Path to pre-computed embeddings file (optional)
        use_fp16: Whether to use FP16 precision
        use_int8: Whether to use INT8 quantization
        use_cuda_graphs: Whether to use CUDA graphs
        zmq_port: ZMQ port to bind to
        max_batch_size: Maximum batch size for processing
        model_name: Transformer model name
        custom_max_length_param: Custom max sequence length
        distance_metric: The distance metric to use
        use_mlx: Whether to compute embeddings with the MLX backend
        enable_warmup: Whether to perform warmup requests on server start
    """
    if not use_mlx:
        print(f"Loading tokenizer for {model_name}...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        print("Tokenizer loaded successfully!")
    else:
        print("Using MLX mode - tokenizer will be loaded separately")
        tokenizer = None

    # Device setup
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    cuda_available = torch.cuda.is_available()

    print(f"MPS available: {mps_available}")
    print(f"CUDA available: {cuda_available}")

    if cuda_available:
        device = torch.device("cuda")
        print("Using CUDA device")
    elif mps_available:
        device = torch.device("mps")
        print("Using MPS device (Apple Silicon)")
    else:
        device = torch.device("cpu")
        print("Using CPU device (no GPU acceleration available)")

    # Load model to the appropriate device
    print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
    print(f"Loading model {model_name}... (this may take a while if downloading)")

    if use_mlx:
        # For MLX models, we need to use the MLX embedding computation
        print("MLX model detected - using MLX backend for embeddings")
        model = None  # We'll handle MLX separately
        tokenizer = None
    else:
        # Use standard transformers for non-MLX models; the fast tokenizer was
        # already loaded above, so only the model needs to be loaded here.
        model = AutoModel.from_pretrained(model_name).to(device).eval()
        print(f"Model {model_name} loaded successfully!")

    # Check port availability
    import socket

    def check_port(port):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            return s.connect_ex(("localhost", port)) == 0

    if check_port(zmq_port):
        print(f"{RED}Port {zmq_port} is already in use{RESET}")
        return

    # Apply model optimizations (similar to DiskANN version).
    # Skipped in MLX mode, where `model` is None and these optimizations do not apply.
    if model is not None and use_fp16 and (cuda_available or mps_available):
        model = model.half()
        model = torch.compile(model)
        print(f"Using FP16 precision with model: {model_name}")
    elif model is not None and use_int8:
        print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
        from torchao.quantization import (
            quantize_,
            Int8DynamicActivationInt8WeightConfig,
        )

        quantize_(model, Int8DynamicActivationInt8WeightConfig())
        model = torch.compile(model)
        model.eval()
        print("- Model successfully quantized and compiled")

    # Load passages
    if passages_data:
        passages = SimplePassageLoader(passages_data)
        print(f"Using provided passages data: {len(passages)} passages")
    elif passages_file:
        # Check if it's a metadata file or a single passages file
        if passages_file.endswith(".meta.json"):
            passages = load_passages_from_metadata(passages_file)
            # Store the meta path for future reference
            passages._meta_path = passages_file
        else:
            # Try to find metadata file in same directory
            passages_dir = Path(passages_file).parent
            meta_files = list(passages_dir.glob("*.meta.json"))
            if meta_files:
                print(f"Found metadata file: {meta_files[0]}, using lazy loading")
                passages = load_passages_from_metadata(str(meta_files[0]))
            else:
                # Fallback to original single file loading (will cause warnings)
                print(
                    "WARNING: No metadata file found, using single file loading "
                    "(may cause missing passage warnings)"
                )
                # Use empty loader to avoid massive warnings
                passages = SimplePassageLoader()
    else:
        passages = SimplePassageLoader()
        print("No passages provided, using empty loader")

    # Load embeddings if provided
    _embeddings = None
    if embeddings_file and os.path.exists(embeddings_file):
        try:
            with open(embeddings_file, "rb") as f:
                _embeddings = pickle.load(f)
            print(f"Loaded embeddings from {embeddings_file}")
        except Exception as e:
            print(f"Error loading embeddings: {e}")

    class DeviceTimer:
        """Device event-based timer for accurate timing."""

        def __init__(self, name="", device=device):
            self.name = name
            self.device = device
            self.start_time = 0
            self.end_time = 0

            if cuda_available:
                self.start_event = torch.cuda.Event(enable_timing=True)
                self.end_event = torch.cuda.Event(enable_timing=True)
            else:
                self.start_event = None
                self.end_event = None

        @contextmanager
        def timing(self):
            self.start()
            yield
            self.end()

        def start(self):
            if cuda_available:
                torch.cuda.synchronize()
                self.start_event.record()
            else:
                if self.device.type == "mps":
                    torch.mps.synchronize()
                self.start_time = time.time()

        def end(self):
            if cuda_available:
                self.end_event.record()
                torch.cuda.synchronize()
            else:
                if self.device.type == "mps":
                    torch.mps.synchronize()
                self.end_time = time.time()

        def elapsed_time(self):
            if cuda_available:
                return self.start_event.elapsed_time(self.end_event) / 1000.0
            else:
                return self.end_time - self.start_time

        def print_elapsed(self):
            return  # Disabled for now

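    # Typical DeviceTimer usage (illustrative): create a timer, wrap the timed region in
    # `with timer.timing(): ...`, then read `timer.elapsed_time()` to get seconds elapsed.
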
    def _process_batch_mlx(texts_batch, ids_batch, missing_ids):
        """Process a batch of texts using MLX backend"""
        try:
            # Import MLX embedding computation from main API
            from leann.api import compute_embeddings

            # Compute embeddings using MLX
            embeddings = compute_embeddings(texts_batch, model_name, use_mlx=True)

            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: MLX embeddings computed for {len(texts_batch)} texts"
            )
            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: Embedding shape: {embeddings.shape}"
            )

            return embeddings

        except Exception as e:
            print(
                f"[leann_backend_hnsw.hnsw_embedding_server LOG]: ERROR in MLX processing: {e}"
            )
            raise

    def process_batch(texts_batch, ids_batch, missing_ids):
        """Process a batch of texts and return embeddings"""

        # Handle MLX models separately
        if use_mlx:
            return _process_batch_mlx(texts_batch, ids_batch, missing_ids)

        _is_e5_model = "e5" in model_name.lower()
        _is_bge_model = "bge" in model_name.lower()
        batch_size = len(texts_batch)

        # Allow empty texts to pass through (remove validation)

        # E5 model preprocessing
        if _is_e5_model:
            processed_texts_batch = [f"passage: {text}" for text in texts_batch]
        else:
            processed_texts_batch = texts_batch

        # Set max length
        if _is_e5_model:
            current_max_length = (
                custom_max_length_param if custom_max_length_param is not None else 512
            )
        else:
            current_max_length = (
                custom_max_length_param if custom_max_length_param is not None else 256
            )

        tokenize_timer = DeviceTimer("tokenization (batch)", device)
        to_device_timer = DeviceTimer("transfer to device (batch)", device)
        embed_timer = DeviceTimer("embedding (batch)", device)
        pool_timer = DeviceTimer("pooling (batch)", device)
        norm_timer = DeviceTimer("normalization (batch)", device)

        with tokenize_timer.timing():
            encoded_batch = tokenizer(
                processed_texts_batch,
                padding="max_length",
                truncation=True,
                max_length=current_max_length,
                return_tensors="pt",
                return_token_type_ids=False,
            )

        seq_length = encoded_batch["input_ids"].size(1)

        with to_device_timer.timing():
            enc = {k: v.to(device) for k, v in encoded_batch.items()}

        with torch.no_grad():
            with embed_timer.timing():
                out = model(enc["input_ids"], enc["attention_mask"])

            with pool_timer.timing():
                if _is_bge_model:
                    pooled_embeddings = out.last_hidden_state[:, 0]
                elif not hasattr(out, "last_hidden_state"):
                    if isinstance(out, torch.Tensor) and len(out.shape) == 2:
                        pooled_embeddings = out
                    else:
                        print(
                            f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
                        )
                        hidden_dim = getattr(
                            model.config, "hidden_size", 384 if _is_e5_model else 768
                        )
                        pooled_embeddings = torch.zeros(
                            (batch_size, hidden_dim),
                            device=device,
                            dtype=enc["input_ids"].dtype
                            if hasattr(enc["input_ids"], "dtype")
                            else torch.float32,
                        )
                elif _is_e5_model:
                    pooled_embeddings = e5_average_pool(
                        out.last_hidden_state, enc["attention_mask"]
                    )
                else:
                    hidden_states = out.last_hidden_state
                    mask_expanded = (
                        enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
                    )
                    sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
                    sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
                    pooled_embeddings = sum_embeddings / sum_mask

            final_embeddings = pooled_embeddings
            if _is_e5_model or _is_bge_model:
                with norm_timer.timing():
                    final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)

        if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
            print(
                f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
                f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}"
            )
            dim_size = final_embeddings.shape[-1]
            error_output = torch.zeros(
                (batch_size, dim_size), device="cpu", dtype=torch.float32
            ).numpy()
            print(
                f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}"
            )
            return error_output

        return final_embeddings.cpu().numpy()

    def client_warmup(zmq_port):
        """Perform client-side warmup"""
        time.sleep(2)
        print(f"Performing client-side warmup with model {model_name}...")

        # Get actual passage IDs from the loaded passages
        sample_ids = []
        if hasattr(passages, "keys") and len(passages) > 0:
            available_ids = list(passages.keys())
            # Take up to 5 actual IDs, but at least 1
            sample_ids = available_ids[: min(5, len(available_ids))]
            print(f"Using actual passage IDs for warmup: {sample_ids}")
        else:
            print("No passages available for warmup, skipping warmup...")
            return

        try:
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            socket.connect(f"tcp://localhost:{zmq_port}")
            socket.setsockopt(zmq.RCVTIMEO, 30000)
            socket.setsockopt(zmq.SNDTIMEO, 30000)

            try:
                ids_to_send = [int(x) for x in sample_ids]
            except ValueError:
                print("Warning: Could not convert sample IDs to integers, skipping warmup")
                return

            if not ids_to_send:
                print("Skipping warmup send.")
                return

            request_payload = [ids_to_send]
            request_bytes = msgpack.packb(request_payload)

            for i in range(3):
                print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...")
                socket.send(request_bytes)
                response_bytes = socket.recv()

                response_payload = msgpack.unpackb(response_bytes)
                dimensions = response_payload[0]
                embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
                print(
                    f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings"
                )
                time.sleep(0.1)

            print("Client-side MessagePack ZMQ warmup complete")
            socket.close()
            context.term()
        except Exception as e:
            print(f"Error during MessagePack ZMQ warmup: {e}")

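    # Wire protocol handled by zmq_server_thread (summary of the handlers below):
    #   - embedding request:  msgpack [[id1, id2, ...]]            -> [[rows, dim], flat_float32_list]
    #   - distance request:   msgpack [[id1, ...], [query_vector]] -> [[d1, d2, ...]]
    #   - control messages:   ["__QUERY_META_PATH__"], ["__UPDATE_META_PATH__", path],
    #                         ["__QUERY_MODEL__"], ["__UPDATE_MODEL__", name]
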
    def zmq_server_thread():
        """ZMQ server thread"""
        nonlocal passages, model, tokenizer, model_name
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind(f"tcp://*:{zmq_port}")
        print(f"HNSW ZMQ server listening on port {zmq_port}")

        socket.setsockopt(zmq.RCVTIMEO, 300000)
        socket.setsockopt(zmq.SNDTIMEO, 300000)

        while True:
            try:
                message_bytes = socket.recv()
                print(f"Received ZMQ request of size {len(message_bytes)} bytes")

                e2e_start = time.time()
                lookup_timer = DeviceTimer("text lookup", device)

                try:
                    request_payload = msgpack.unpackb(message_bytes)
                    if isinstance(request_payload, list):
                        print(f"DEBUG: request_payload length: {len(request_payload)}")
                        for i, item in enumerate(request_payload):
                            print(
                                f"DEBUG: request_payload[{i}]: {type(item)} - "
                                f"{item if len(str(item)) < 100 else str(item)[:100] + '...'}"
                            )

                    # Handle control messages for meta path and model management
                    if isinstance(request_payload, list) and len(request_payload) >= 1:
                        if request_payload[0] == "__QUERY_META_PATH__":
                            # Return the current meta path being used by the server
                            current_meta_path = getattr(passages, "_meta_path", "")
                            response = [current_meta_path]
                            socket.send(msgpack.packb(response))
                            continue

                        elif (
                            request_payload[0] == "__UPDATE_META_PATH__"
                            and len(request_payload) >= 2
                        ):
                            # Update the server's meta path and reload passages
                            new_meta_path = request_payload[1]
                            try:
                                print(f"INFO: Updating server meta path to: {new_meta_path}")
                                # Reload passages from the new meta file
                                passages = load_passages_from_metadata(new_meta_path)
                                # Store the meta path for future queries
                                passages._meta_path = new_meta_path
                                response = ["SUCCESS"]
                                print(
                                    f"INFO: Successfully updated meta path and reloaded {len(passages)} passages"
                                )
                            except Exception as e:
                                print(f"ERROR: Failed to update meta path: {e}")
                                response = ["FAILED", str(e)]
                            socket.send(msgpack.packb(response))
                            continue

                        elif request_payload[0] == "__QUERY_MODEL__":
                            # Return the current model being used by the server
                            response = [model_name]
                            socket.send(msgpack.packb(response))
                            continue

                        elif (
                            request_payload[0] == "__UPDATE_MODEL__"
                            and len(request_payload) >= 2
                        ):
                            # Update the server's embedding model
                            new_model_name = request_payload[1]
                            try:
                                print(
                                    f"INFO: Updating server model from {model_name} to: {new_model_name}"
                                )

                                # Clean up old model to free memory
                                print("INFO: Releasing old model from memory...")
                                old_model = model
                                old_tokenizer = tokenizer

                                # Load new tokenizer first
                                print(f"Loading new tokenizer for {new_model_name}...")
                                tokenizer = AutoTokenizer.from_pretrained(
                                    new_model_name, use_fast=True
                                )

                                # Load new model
                                print(f"Loading new model {new_model_name}...")
                                model = AutoModel.from_pretrained(new_model_name)
                                model.to(device)
                                model.eval()

                                # Now safely delete old model after new one is loaded
                                del old_model
                                del old_tokenizer

                                # Clear GPU cache if available
                                if device.type == "cuda":
                                    torch.cuda.empty_cache()
                                    print("INFO: Cleared CUDA cache")
                                elif device.type == "mps":
                                    torch.mps.empty_cache()
                                    print("INFO: Cleared MPS cache")

                                # Update model name
                                model_name = new_model_name

                                # Force garbage collection
                                import gc

                                gc.collect()
                                print("INFO: Memory cleanup completed")

                                response = ["SUCCESS"]
                                print(f"INFO: Successfully updated model to: {new_model_name}")
                            except Exception as e:
                                print(f"ERROR: Failed to update model: {e}")
                                response = ["FAILED", str(e)]
                            socket.send(msgpack.packb(response))
                            continue

                    # Handle distance calculation requests
                    if (
                        isinstance(request_payload, list)
                        and len(request_payload) == 2
                        and isinstance(request_payload[0], list)
                        and isinstance(request_payload[1], list)
                    ):
                        node_ids = request_payload[0]
                        query_vector = np.array(request_payload[1], dtype=np.float32)

                        print("DEBUG: Distance calculation request received")
                        print(f"  Node IDs: {node_ids}")
                        print(f"  Query vector dim: {len(query_vector)}")
                        print(f"  Passages loaded: {len(passages)}")

                        # Get embeddings for node IDs
                        texts = []
                        missing_ids = []
                        with lookup_timer.timing():
                            for nid in node_ids:
                                print(f"DEBUG: Looking up passage ID {nid}")
                                try:
                                    txtinfo = passages[nid]
                                    if txtinfo is None:
                                        print(f"ERROR: Passage with ID {nid} returned None")
                                        print(f"ERROR: txtinfo: {txtinfo}")
                                        raise RuntimeError(
                                            f"FATAL: Passage with ID {nid} returned None"
                                        )
                                    txt = txtinfo["text"]  # Allow empty text to pass through
                                    print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
                                    texts.append(txt)
                                except KeyError:
                                    print(f"ERROR: Passage ID {nid} not found in passages dict")
                                    print(
                                        f"ERROR: Available passage IDs: {list(passages.keys())}..."
                                    )
                                    raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
                                except Exception as e:
                                    print(f"ERROR: Exception looking up passage ID {nid}: {e}")
                                    raise
                        lookup_timer.print_elapsed()

                        # Process embeddings in chunks if needed
                        all_node_embeddings = []
                        total_size = len(texts)

                        if total_size > max_batch_size:
                            for i in range(0, total_size, max_batch_size):
                                end_idx = min(i + max_batch_size, total_size)
                                chunk_texts = texts[i:end_idx]
                                chunk_ids = node_ids[i:end_idx]

                                embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
                                all_node_embeddings.append(embeddings_chunk)

                                if cuda_available:
                                    torch.cuda.empty_cache()
                                elif device.type == "mps":
                                    torch.mps.empty_cache()

                            node_embeddings = np.vstack(all_node_embeddings)
                        else:
                            node_embeddings = process_batch(texts, node_ids, missing_ids)

                        # Calculate distances
                        query_tensor = torch.tensor(query_vector, device=device).float()
                        node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()

                        calc_timer = DeviceTimer("distance calculation", device)
                        with calc_timer.timing():
                            with torch.no_grad():
                                if distance_metric == "l2":
                                    node_embeddings_np = (
                                        node_embeddings_tensor.cpu().numpy().astype(np.float32)
                                    )
                                    query_np = query_tensor.cpu().numpy().astype(np.float32)
                                    distances = np.sum(
                                        np.square(node_embeddings_np - query_np.reshape(1, -1)),
                                        axis=1,
                                    )
                                else:  # mips or cosine
                                    node_embeddings_np = node_embeddings_tensor.cpu().numpy()
                                    query_np = query_tensor.cpu().numpy()
                                    distances = -np.dot(node_embeddings_np, query_np)
                        calc_timer.print_elapsed()

                        try:
                            response_payload = distances.flatten().tolist()
                            response_bytes = msgpack.packb([response_payload], use_single_float=True)
                            print(f"Sending distance response with {len(distances)} distances")
                        except Exception as pack_error:
                            print(
                                f"ERROR: Error packing MessagePack distance response: {pack_error}"
                            )
                            print(f"ERROR: distances shape: {distances.shape}")
                            print(f"ERROR: distances dtype: {distances.dtype}")
                            print(f"ERROR: distances content: {distances}")
                            print(f"ERROR: node_ids: {node_ids}")
                            print(f"ERROR: query_vector shape: {query_vector.shape}")
                            # Still return empty for now but with full error info
                            response_bytes = msgpack.packb([[]])

                        socket.send(response_bytes)

                        if device.type == "cuda":
                            torch.cuda.synchronize()
                        elif device.type == "mps":
                            torch.mps.synchronize()
                        e2e_end = time.time()
                        print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
                        continue

                    # Standard embedding request
                    if (
                        not isinstance(request_payload, list)
                        or len(request_payload) != 1
                        or not isinstance(request_payload[0], list)
                    ):
                        print(
                            f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}"
                        )
                        socket.send(msgpack.packb([[], []]))
                        continue

                    node_ids = request_payload[0]
                    print(f"Request for {len(node_ids)} node embeddings")

                except Exception as unpack_error:
                    print(f"Error unpacking MessagePack request: {unpack_error}")
                    socket.send(msgpack.packb([[], []]))
                    continue

                # Look up texts by node IDs
                texts = []
                missing_ids = []
                with lookup_timer.timing():
                    for nid in node_ids:
                        try:
                            txtinfo = passages[nid]
                            if txtinfo is None or txtinfo["text"] == "":
                                raise RuntimeError(
                                    f"FATAL: Passage with ID {nid} not found - failing fast"
                                )
                            else:
                                txt = txtinfo["text"]
                        except (KeyError, IndexError):
                            raise RuntimeError(
                                f"FATAL: Passage with ID {nid} not found - failing fast"
                            )
                        texts.append(txt)
                lookup_timer.print_elapsed()

                if missing_ids:
                    print(f"Missing passages for IDs: {missing_ids}")

                # Process in chunks
                total_size = len(texts)
                print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")

                all_embeddings = []

                if total_size > max_batch_size:
                    print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
                    for i in range(0, total_size, max_batch_size):
                        end_idx = min(i + max_batch_size, total_size)
                        print(
                            f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: "
                            f"items {i} to {end_idx - 1}"
                        )

                        chunk_texts = texts[i:end_idx]
                        chunk_ids = node_ids[i:end_idx]

                        embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
                        all_embeddings.append(embeddings_chunk)

                        if cuda_available:
                            torch.cuda.empty_cache()
                        elif device.type == "mps":
                            torch.mps.empty_cache()

                    hidden = np.vstack(all_embeddings)
                    print(f"Combined embeddings shape: {hidden.shape}")
                else:
                    hidden = process_batch(texts, node_ids, missing_ids)

                # Serialization and response
                ser_start = time.time()

                print(
                    f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | "
                    f"Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}"
                )
                if np.isnan(hidden).any() or np.isinf(hidden).any():
                    print(
                        f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
                        f"Requested IDs (sample): {node_ids[:5]}...{RESET}"
                    )
                    assert False

                try:
                    hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
                    response_payload = [
                        list(hidden_contiguous_f32.shape),
                        hidden_contiguous_f32.flatten().tolist(),
                    ]
                    response_bytes = msgpack.packb(response_payload, use_single_float=True)
                except Exception as pack_error:
                    print(f"Error packing MessagePack response: {pack_error}")
                    response_bytes = msgpack.packb([[], []])

                socket.send(response_bytes)
                ser_end = time.time()

                print(f"Serialize time: {ser_end - ser_start:.6f} seconds")

                if device.type == "cuda":
                    torch.cuda.synchronize()
                elif device.type == "mps":
                    torch.mps.synchronize()
                e2e_end = time.time()
                print(f"ZMQ E2E time: {e2e_end - e2e_start:.6f} seconds")

            except zmq.Again:
                print("ZMQ socket timeout, continuing to listen")
                continue
            except Exception as e:
                print(f"Error in ZMQ server loop: {e}")
                import traceback

                traceback.print_exc()
                try:
                    socket.send(msgpack.packb([[], []]))
                except Exception:
                    pass

    # Start warmup and server threads
    if enable_warmup and len(passages) > 0:
        print("Warmup enabled: starting warmup thread")
        warmup_thread = threading.Thread(target=client_warmup, args=(zmq_port,))
        warmup_thread.daemon = True
        warmup_thread.start()
    else:
        print(
            "Warmup disabled or no passages available "
            f"(enable_warmup={enable_warmup}, passages={len(passages)})"
        )

    zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
    zmq_thread.start()
    print(f"Started HNSW ZMQ server thread on port {zmq_port}")

    # Keep the main thread alive
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        print("HNSW Server shutting down...")
        return

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="HNSW Embedding service")
    parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
    parser.add_argument(
        "--passages-file",
        type=str,
        help="JSON file containing passage ID to text mapping",
    )
    parser.add_argument(
        "--embeddings-file",
        type=str,
        help="Pickle file containing pre-computed embeddings",
    )
    parser.add_argument("--use-fp16", action="store_true", default=False)
    parser.add_argument("--use-int8", action="store_true", default=False)
    parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=128,
        help="Maximum batch size before splitting",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model name",
    )
    parser.add_argument(
        "--custom-max-length",
        type=int,
        default=None,
        help="Override model's default max sequence length",
    )
    parser.add_argument(
        "--distance-metric", type=str, default="mips", help="Distance metric to use"
    )
    parser.add_argument(
        "--use-mlx",
        action="store_true",
        default=False,
        help="Use MLX for model inference",
    )
    parser.add_argument(
        "--disable-warmup",
        action="store_true",
        default=False,
        help="Disable warmup requests on server start",
    )

    args = parser.parse_args()

    # Create and start the HNSW embedding server
    create_hnsw_embedding_server(
        passages_file=args.passages_file,
        embeddings_file=args.embeddings_file,
        use_fp16=args.use_fp16,
        use_int8=args.use_int8,
        use_cuda_graphs=args.use_cuda_graphs,
        zmq_port=args.zmq_port,
        max_batch_size=args.max_batch_size,
        model_name=args.model_name,
        custom_max_length_param=args.custom_max_length,
        distance_metric=args.distance_metric,
        use_mlx=args.use_mlx,
        enable_warmup=not args.disable_warmup,
    )
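
# Minimal client sketch (illustrative only; assumes the server above is already running
# on port 5555 and that passage IDs 0 and 1 exist in the loaded index):
#
#   import msgpack, zmq
#   ctx = zmq.Context()
#   sock = ctx.socket(zmq.REQ)
#   sock.connect("tcp://localhost:5555")
#   sock.send(msgpack.packb([[0, 1]]))          # request embeddings for two passage IDs
#   shape, flat = msgpack.unpackb(sock.recv())  # response is [[rows, dim], flat float list]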