refactor: check if current emb_server has correct passages/embedder

Andy Lee
2025-07-13 22:33:33 -07:00
parent 77ac013a74
commit 3b5a185e60
5 changed files with 915 additions and 229 deletions
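The core of this refactor is that a client can now interrogate a running embedding server before reusing it: four msgpack control messages (__QUERY_META_PATH__, __UPDATE_META_PATH__, __QUERY_MODEL__, __UPDATE_MODEL__) let callers check and, if needed, hot-swap the loaded passages and embedding model. A minimal client-side sketch of that check, assuming pyzmq and msgpack are installed and a server is already listening on the given port (the helper name and timeouts are illustrative, not part of this diff):

import msgpack
import zmq

def server_matches(zmq_port: int, expected_meta_path: str, expected_model: str) -> bool:
    """Ask a running embedding server which index and embedder it currently serves."""
    context = zmq.Context()
    sock = context.socket(zmq.REQ)
    sock.setsockopt(zmq.RCVTIMEO, 5000)
    sock.setsockopt(zmq.SNDTIMEO, 5000)
    sock.connect(f"tcp://localhost:{zmq_port}")
    try:
        # The server replies [current_meta_path] to this control message.
        sock.send(msgpack.packb(["__QUERY_META_PATH__"]))
        current_meta = msgpack.unpackb(sock.recv())[0]
        # The server replies [model_name] to this control message.
        sock.send(msgpack.packb(["__QUERY_MODEL__"]))
        current_model = msgpack.unpackb(sock.recv())[0]
        return current_meta == expected_meta_path and current_model == expected_model
    finally:
        sock.close()
        context.term()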


@@ -17,10 +17,12 @@ import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
import sys
RED = "\033[91m"
RESET = "\033[0m"
def is_similarity_metric():
"""
Check if the metric type is similarity-based (like inner product).
@@ -28,22 +30,27 @@ def is_similarity_metric():
"""
return True # 1 is METRIC_INNER_PRODUCT in FAISS
# Function for E5-style average pooling
import torch
from torch import Tensor
import torch.nn.functional as F
def e5_average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
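# Zero out hidden states at padded positions, then average over the real tokens only.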
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
class SimplePassageLoader:
"""
Simple passage loader that replaces config.py dependencies
"""
def __init__(self, passages_data: Optional[Dict[str, Any]] = None):
self.passages_data = passages_data or {}
self._meta_path = ""
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID"""
str_id = str(passage_id)
@@ -52,54 +59,57 @@ class SimplePassageLoader:
else:
# Return empty text for missing passages
return {"text": ""}
def __len__(self) -> int:
return len(self.passages_data)
def keys(self):
return self.passages_data.keys()
def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
"""
Load passages using metadata file with PassageManager for lazy loading
"""
# Load metadata to get passage sources
with open(meta_file, 'r') as f:
with open(meta_file, "r") as f:
meta = json.load(f)
# Import PassageManager dynamically to avoid circular imports
import sys
import importlib.util
# Find the leann package directory relative to this file
current_dir = Path(__file__).parent
leann_core_path = current_dir.parent.parent / "leann-core" / "src"
sys.path.insert(0, str(leann_core_path))
try:
from leann.api import PassageManager
passage_manager = PassageManager(meta['passage_sources'])
passage_manager = PassageManager(meta["passage_sources"])
finally:
sys.path.pop(0)
# Load label map
passages_dir = Path(meta_file).parent
label_map_file = passages_dir / "leann.labels.map"
if label_map_file.exists():
import pickle
with open(label_map_file, 'rb') as f:
with open(label_map_file, "rb") as f:
label_map = pickle.load(f)
print(f"Loaded label map with {len(label_map)} entries")
else:
raise FileNotFoundError(f"Label map file not found: {label_map_file}")
print(f"Initialized lazy passage loading for {len(label_map)} passages")
class LazyPassageLoader(SimplePassageLoader):
def __init__(self, passage_manager, label_map):
self.passage_manager = passage_manager
self.label_map = label_map
# Initialize parent with empty data
super().__init__({})
def __getitem__(self, passage_id: Union[str, int]) -> Dict[str, str]:
"""Get passage by ID with lazy loading"""
try:
@@ -118,12 +128,16 @@ def load_passages_from_metadata(meta_file: str) -> SimplePassageLoader:
except Exception as e:
print(f"DEBUG: Exception getting passage {passage_id}: {e}")
return {"text": ""}
def __len__(self) -> int:
return len(self.label_map)
def keys(self):
return self.label_map.keys()
return LazyPassageLoader(passage_manager, label_map)
def create_hnsw_embedding_server(
passages_file: Optional[str] = None,
passages_data: Optional[Dict[str, str]] = None,
@@ -139,7 +153,7 @@ def create_hnsw_embedding_server(
):
"""
Create and start a ZMQ-based embedding server for HNSW backend.
Args:
passages_file: Path to JSON file containing passage ID -> text mapping
passages_data: Direct passage data dict (alternative to passages_file)
@@ -156,14 +170,14 @@ def create_hnsw_embedding_server(
print(f"Loading tokenizer for {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
print(f"Tokenizer loaded successfully!")
# Device setup
mps_available = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()
print(f"MPS available: {mps_available}")
print(f"CUDA available: {cuda_available}")
if cuda_available:
device = torch.device("cuda")
print("Using CUDA device")
@@ -173,7 +187,7 @@ def create_hnsw_embedding_server(
else:
device = torch.device("cpu")
print("Using CPU device (no GPU acceleration available)")
# Load model to the appropriate device
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Loading model {model_name}... (this may take a while if downloading)")
@@ -182,9 +196,10 @@ def create_hnsw_embedding_server(
# Check port availability
import socket
def check_port(port):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
@@ -196,8 +211,14 @@ def create_hnsw_embedding_server(
model = torch.compile(model)
print(f"Using FP16 precision with model: {model_name}")
elif use_int8:
print("- Using TorchAO for Int8 dynamic activation and Int8 weight quantization")
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
print(
"- Using TorchAO for Int8 dynamic activation and Int8 weight quantization"
)
from torchao.quantization import (
quantize_,
Int8DynamicActivationInt8WeightConfig,
)
quantize_(model, Int8DynamicActivationInt8WeightConfig())
model = torch.compile(model)
model.eval()
@@ -209,8 +230,10 @@ def create_hnsw_embedding_server(
print(f"Using provided passages data: {len(passages)} passages")
elif passages_file:
# Check if it's a metadata file or a single passages file
if passages_file.endswith('.meta.json'):
if passages_file.endswith(".meta.json"):
passages = load_passages_from_metadata(passages_file)
# Store the meta path for future reference
passages._meta_path = passages_file
else:
# Try to find metadata file in same directory
passages_dir = Path(passages_file).parent
@@ -220,8 +243,12 @@ def create_hnsw_embedding_server(
passages = load_passages_from_metadata(str(meta_files[0]))
else:
# Fallback to original single file loading (will cause warnings)
print("WARNING: No metadata file found, using single file loading (may cause missing passage warnings)")
passages = SimplePassageLoader() # Use empty loader to avoid massive warnings
print(
"WARNING: No metadata file found, using single file loading (may cause missing passage warnings)"
)
passages = (
SimplePassageLoader()
) # Use empty loader to avoid massive warnings
else:
passages = SimplePassageLoader()
print("No passages provided, using empty loader")
@@ -238,12 +265,13 @@ def create_hnsw_embedding_server(
class DeviceTimer:
"""Device event-based timer for accurate timing."""
def __init__(self, name="", device=device):
self.name = name
self.device = device
self.start_time = 0
self.end_time = 0
if cuda_available:
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
@@ -289,30 +317,31 @@ def create_hnsw_embedding_server(
_is_e5_model = "e5" in model_name.lower()
_is_bge_model = "bge" in model_name.lower()
batch_size = len(texts_batch)
# Validate no empty texts
for i, text in enumerate(texts_batch):
if not text or text.strip() == "":
raise RuntimeError(f"FATAL: Empty text at batch index {i}, ID: {ids_batch[i] if i < len(ids_batch) else 'unknown'}")
# Allow empty texts to pass through (remove validation)
# E5 model preprocessing
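# E5 models expect a 'passage: ' prefix on documents (and 'query: ' on queries); skipping it degrades retrieval quality.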
if _is_e5_model:
processed_texts_batch = [f"passage: {text}" for text in texts_batch]
else:
processed_texts_batch = texts_batch
# Set max length
if _is_e5_model:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 512
current_max_length = (
custom_max_length_param if custom_max_length_param is not None else 512
)
else:
current_max_length = custom_max_length_param if custom_max_length_param is not None else 256
current_max_length = (
custom_max_length_param if custom_max_length_param is not None else 256
)
tokenize_timer = DeviceTimer("tokenization (batch)", device)
to_device_timer = DeviceTimer("transfer to device (batch)", device)
embed_timer = DeviceTimer("embedding (batch)", device)
pool_timer = DeviceTimer("pooling (batch)", device)
norm_timer = DeviceTimer("normalization (batch)", device)
with tokenize_timer.timing():
encoded_batch = tokenizer(
processed_texts_batch,
@@ -322,48 +351,71 @@ def create_hnsw_embedding_server(
return_tensors="pt",
return_token_type_ids=False,
)
seq_length = encoded_batch["input_ids"].size(1)
with to_device_timer.timing():
enc = {k: v.to(device) for k, v in encoded_batch.items()}
with torch.no_grad():
with embed_timer.timing():
out = model(enc["input_ids"], enc["attention_mask"])
with pool_timer.timing():
if _is_bge_model:
pooled_embeddings = out.last_hidden_state[:, 0]
elif not hasattr(out, 'last_hidden_state'):
elif not hasattr(out, "last_hidden_state"):
if isinstance(out, torch.Tensor) and len(out.shape) == 2:
pooled_embeddings = out
else:
print(f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}")
hidden_dim = getattr(model.config, 'hidden_size', 384 if _is_e5_model else 768)
pooled_embeddings = torch.zeros((batch_size, hidden_dim), device=device, dtype=enc["input_ids"].dtype if hasattr(enc["input_ids"], "dtype") else torch.float32)
print(
f"{RED}ERROR: Cannot determine how to pool. Output shape: {out.shape if isinstance(out, torch.Tensor) else 'N/A'}{RESET}"
)
hidden_dim = getattr(
model.config, "hidden_size", 384 if _is_e5_model else 768
)
pooled_embeddings = torch.zeros(
(batch_size, hidden_dim),
device=device,
dtype=enc["input_ids"].dtype
if hasattr(enc["input_ids"], "dtype")
else torch.float32,
)
elif _is_e5_model:
pooled_embeddings = e5_average_pool(out.last_hidden_state, enc['attention_mask'])
pooled_embeddings = e5_average_pool(
out.last_hidden_state, enc["attention_mask"]
)
else:
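# Default path: attention-mask-weighted mean pooling over token embeddings.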
hidden_states = out.last_hidden_state
mask_expanded = enc["attention_mask"].unsqueeze(-1).expand(hidden_states.size()).float()
mask_expanded = (
enc["attention_mask"]
.unsqueeze(-1)
.expand(hidden_states.size())
.float()
)
sum_embeddings = torch.sum(hidden_states * mask_expanded, 1)
sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
pooled_embeddings = sum_embeddings / sum_mask
final_embeddings = pooled_embeddings
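# E5 and BGE embeddings are L2-normalized so that inner product behaves as cosine similarity.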
if _is_e5_model or _is_bge_model:
with norm_timer.timing():
final_embeddings = F.normalize(pooled_embeddings, p=2, dim=1)
if torch.isnan(final_embeddings).any() or torch.isinf(final_embeddings).any():
print(f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}")
print(
f"{RED}!!! In process_batch: NaN or Inf detected in final_embeddings! "
f"Model: {model_name}, E5: {_is_e5_model}. IDs (sample): {ids_batch[:5]}...{RESET}"
)
dim_size = final_embeddings.shape[-1]
error_output = torch.zeros((batch_size, dim_size), device='cpu', dtype=torch.float32).numpy()
print(f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}")
error_output = torch.zeros(
(batch_size, dim_size), device="cpu", dtype=torch.float32
).numpy()
print(
f"{RED}Returning zero embeddings of shape ({batch_size}, {dim_size}) due to NaN/Inf.{RESET}"
)
return error_output
return final_embeddings.cpu().numpy()
def client_warmup(zmq_port):
@@ -371,7 +423,7 @@ def create_hnsw_embedding_server(
time.sleep(2)
print(f"Performing client-side warmup with model {model_name}...")
sample_ids = ["1", "2", "3", "4", "5"]
try:
context = zmq.Context()
socket = context.socket(zmq.REQ)
@@ -379,12 +431,12 @@ def create_hnsw_embedding_server(
socket.setsockopt(zmq.RCVTIMEO, 30000)
socket.setsockopt(zmq.SNDTIMEO, 30000)
try:
ids_to_send = [int(x) for x in sample_ids]
except ValueError:
ids_to_send = []
if not ids_to_send:
print("Skipping warmup send.")
return
@@ -392,14 +444,18 @@ def create_hnsw_embedding_server(
request_bytes = msgpack.packb(request_payload)
for i in range(3):
print(f"Sending warmup request {i+1}/3 via ZMQ (MessagePack)...")
print(f"Sending warmup request {i + 1}/3 via ZMQ (MessagePack)...")
socket.send(request_bytes)
response_bytes = socket.recv()
response_payload = msgpack.unpackb(response_bytes)
dimensions = response_payload[0]
embeddings_count = dimensions[0] if dimensions and len(dimensions) > 0 else 0
print(f"Warmup request {i+1}/3 successful, received {embeddings_count} embeddings")
embeddings_count = (
dimensions[0] if dimensions and len(dimensions) > 0 else 0
)
print(
f"Warmup request {i + 1}/3 successful, received {embeddings_count} embeddings"
)
time.sleep(0.1)
print("Client-side MessagePack ZMQ warmup complete")
@@ -410,6 +466,7 @@ def create_hnsw_embedding_server(
def zmq_server_thread():
"""ZMQ server thread"""
nonlocal passages, model, tokenizer, model_name
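# Rebound in place when __UPDATE_META_PATH__ / __UPDATE_MODEL__ control messages arrive.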
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
@@ -428,94 +485,277 @@ def create_hnsw_embedding_server(
try:
request_payload = msgpack.unpackb(message_bytes)
print(f"DEBUG: Raw request_payload: {request_payload}")
print(f"DEBUG: request_payload type: {type(request_payload)}")
if isinstance(request_payload, list):
print(f"DEBUG: request_payload length: {len(request_payload)}")
for i, item in enumerate(request_payload):
print(
f"DEBUG: request_payload[{i}]: {type(item)} - {item if len(str(item)) < 100 else str(item)[:100] + '...'}"
)
# Handle control messages for meta path and model management
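# __QUERY_META_PATH__ / __QUERY_MODEL__ report the server's current state;
# __UPDATE_META_PATH__ / __UPDATE_MODEL__ hot-swap the passages or the embedding model without restarting the process.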
if isinstance(request_payload, list) and len(request_payload) >= 1:
if request_payload[0] == "__QUERY_META_PATH__":
# Return the current meta path being used by the server
current_meta_path = (
getattr(passages, "_meta_path", "")
if hasattr(passages, "_meta_path")
else ""
)
response = [current_meta_path]
socket.send(msgpack.packb(response))
continue
elif (
request_payload[0] == "__UPDATE_META_PATH__"
and len(request_payload) >= 2
):
# Update the server's meta path and reload passages
new_meta_path = request_payload[1]
try:
print(
f"INFO: Updating server meta path to: {new_meta_path}"
)
# Reload passages from the new meta file
passages = load_passages_from_metadata(new_meta_path)
# Store the meta path for future queries
passages._meta_path = new_meta_path
response = ["SUCCESS"]
print(
f"INFO: Successfully updated meta path and reloaded {len(passages)} passages"
)
except Exception as e:
print(f"ERROR: Failed to update meta path: {e}")
response = ["FAILED", str(e)]
socket.send(msgpack.packb(response))
continue
elif request_payload[0] == "__QUERY_MODEL__":
# Return the current model being used by the server
response = [model_name]
socket.send(msgpack.packb(response))
continue
elif (
request_payload[0] == "__UPDATE_MODEL__"
and len(request_payload) >= 2
):
# Update the server's embedding model
new_model_name = request_payload[1]
try:
print(
f"INFO: Updating server model from {model_name} to: {new_model_name}"
)
# Clean up old model to free memory
print("INFO: Releasing old model from memory...")
old_model = model
old_tokenizer = tokenizer
# Load new tokenizer first
print(f"Loading new tokenizer for {new_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(
new_model_name, use_fast=True
)
# Load new model
print(f"Loading new model {new_model_name}...")
model = AutoModel.from_pretrained(new_model_name)
model.to(device)
model.eval()
# Now safely delete old model after new one is loaded
del old_model
del old_tokenizer
# Clear GPU cache if available
if device.type == "cuda":
torch.cuda.empty_cache()
print("INFO: Cleared CUDA cache")
elif device.type == "mps":
torch.mps.empty_cache()
print("INFO: Cleared MPS cache")
# Update model name
model_name = new_model_name
# Force garbage collection
import gc
gc.collect()
print("INFO: Memory cleanup completed")
response = ["SUCCESS"]
print(
f"INFO: Successfully updated model to: {new_model_name}"
)
except Exception as e:
print(f"ERROR: Failed to update model: {e}")
response = ["FAILED", str(e)]
socket.send(msgpack.packb(response))
continue
# Handle distance calculation requests
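# Wire format: [node_ids, query_vector]; the reply is a msgpack-packed [[distances]].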
if isinstance(request_payload, list) and len(request_payload) == 2 and isinstance(request_payload[0], list) and isinstance(request_payload[1], list):
if (
isinstance(request_payload, list)
and len(request_payload) == 2
and isinstance(request_payload[0], list)
and isinstance(request_payload[1], list)
):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
print(f"Request for distance calculation: {len(node_ids)} nodes, query vector dim: {len(query_vector)}")
print("DEBUG: Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
print(f" Passages loaded: {len(passages)}")
# Get embeddings for node IDs
texts = []
missing_ids = []
with lookup_timer.timing():
for nid in node_ids:
print(f"DEBUG: Looking up passage ID {nid}")
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} returned empty text")
txt = txtinfo["text"]
print(f"DEBUG: Found text for ID {nid}, length: {len(txt)}")
texts.append(txt)
try:
txtinfo = passages[nid]
if txtinfo is None:
print(
f"ERROR: Passage with ID {nid} returned None"
)
print(f"ERROR: txtinfo: {txtinfo}")
raise RuntimeError(
f"FATAL: Passage with ID {nid} returned None"
)
txt = txtinfo[
"text"
] # Allow empty text to pass through
print(
f"DEBUG: Found text for ID {nid}, length: {len(txt)}"
)
texts.append(txt)
except KeyError:
print(
f"ERROR: Passage ID {nid} not found in passages dict"
)
print(
f"ERROR: Available passage IDs: {list(passages.keys())[:10]}..."
)
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e:
print(
f"ERROR: Exception looking up passage ID {nid}: {e}"
)
raise
lookup_timer.print_elapsed()
# Process embeddings in chunks if needed
all_node_embeddings = []
total_size = len(texts)
if total_size > max_batch_size:
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
embeddings_chunk = process_batch(
chunk_texts, chunk_ids, missing_ids
)
all_node_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
node_embeddings = np.vstack(all_node_embeddings)
else:
node_embeddings = process_batch(texts, node_ids, missing_ids)
node_embeddings = process_batch(
texts, node_ids, missing_ids
)
# Calculate distances
query_tensor = torch.tensor(query_vector, device=device).float()
node_embeddings_tensor = torch.tensor(node_embeddings, device=device).float()
node_embeddings_tensor = torch.tensor(
node_embeddings, device=device
).float()
calc_timer = DeviceTimer("distance calculation", device)
with calc_timer.timing():
with torch.no_grad():
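# L2 returns squared Euclidean distance; MIPS/cosine return the negative inner product so smaller always means closer.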
if distance_metric == "l2":
node_embeddings_np = node_embeddings_tensor.cpu().numpy().astype(np.float32)
query_np = query_tensor.cpu().numpy().astype(np.float32)
distances = np.sum(np.square(node_embeddings_np - query_np.reshape(1, -1)), axis=1)
else: # mips or cosine
node_embeddings_np = node_embeddings_tensor.cpu().numpy()
node_embeddings_np = (
node_embeddings_tensor.cpu()
.numpy()
.astype(np.float32)
)
query_np = (
query_tensor.cpu().numpy().astype(np.float32)
)
distances = np.sum(
np.square(
node_embeddings_np - query_np.reshape(1, -1)
),
axis=1,
)
else: # mips or cosine
node_embeddings_np = (
node_embeddings_tensor.cpu().numpy()
)
query_np = query_tensor.cpu().numpy()
distances = -np.dot(node_embeddings_np, query_np)
calc_timer.print_elapsed()
try:
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
print(f"Sending distance response with {len(distances)} distances")
response_bytes = msgpack.packb(
[response_payload], use_single_float=True
)
print(
f"Sending distance response with {len(distances)} distances"
)
except Exception as pack_error:
print(f"Error packing MessagePack distance response: {pack_error}")
print(
f"ERROR: Error packing MessagePack distance response: {pack_error}"
)
print(f"ERROR: distances shape: {distances.shape}")
print(f"ERROR: distances dtype: {distances.dtype}")
print(f"ERROR: distances content: {distances}")
print(f"ERROR: node_ids: {node_ids}")
print(f"ERROR: query_vector shape: {query_vector.shape}")
# Still return empty for now but with full error info
response_bytes = msgpack.packb([[]])
socket.send(response_bytes)
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
e2e_end = time.time()
print(f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds")
print(
f"Distance calculation E2E time: {e2e_end - e2e_start:.6f} seconds"
)
continue
# Standard embedding request
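# Wire format: [[ids...]]; the reply is [shape, flattened float32 values] packed with msgpack.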
if not isinstance(request_payload, list) or len(request_payload) != 1 or not isinstance(request_payload[0], list):
print(f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}")
if (
not isinstance(request_payload, list)
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
print(
f"Error: Invalid MessagePack request format. Expected [[ids...]], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
except Exception as unpack_error:
print(f"Error unpacking MessagePack request: {unpack_error}")
socket.send(msgpack.packb([[], []]))
@@ -529,11 +769,15 @@ def create_hnsw_embedding_server(
try:
txtinfo = passages[nid]
if txtinfo is None or txtinfo["text"] == "":
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found - failing fast"
)
else:
txt = txtinfo["text"]
except (KeyError, IndexError):
raise RuntimeError(f"FATAL: Passage with ID {nid} not found - failing fast")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found - failing fast"
)
texts.append(txt)
lookup_timer.print_elapsed()
@@ -542,27 +786,35 @@ def create_hnsw_embedding_server(
# Process in chunks
total_size = len(texts)
print(f"Total batch size: {total_size}, max_batch_size: {max_batch_size}")
print(
f"Total batch size: {total_size}, max_batch_size: {max_batch_size}"
)
all_embeddings = []
if total_size > max_batch_size:
print(f"Splitting batch of size {total_size} into chunks of {max_batch_size}")
print(
f"Splitting batch of size {total_size} into chunks of {max_batch_size}"
)
for i in range(0, total_size, max_batch_size):
end_idx = min(i + max_batch_size, total_size)
print(f"Processing chunk {i//max_batch_size + 1}/{(total_size + max_batch_size - 1)//max_batch_size}: items {i} to {end_idx-1}")
print(
f"Processing chunk {i // max_batch_size + 1}/{(total_size + max_batch_size - 1) // max_batch_size}: items {i} to {end_idx - 1}"
)
chunk_texts = texts[i:end_idx]
chunk_ids = node_ids[i:end_idx]
embeddings_chunk = process_batch(chunk_texts, chunk_ids, missing_ids)
embeddings_chunk = process_batch(
chunk_texts, chunk_ids, missing_ids
)
all_embeddings.append(embeddings_chunk)
if cuda_available:
torch.cuda.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()
hidden = np.vstack(all_embeddings)
print(f"Combined embeddings shape: {hidden.shape}")
else:
@@ -571,22 +823,30 @@ def create_hnsw_embedding_server(
# Serialization and response
ser_start = time.time()
print(f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}")
print(
f"DEBUG zmq_server_thread: Final 'hidden' array | Shape: {hidden.shape} | Dtype: {hidden.dtype} | Has NaN/Inf: {np.isnan(hidden).any() or np.isinf(hidden).any()}"
)
if np.isnan(hidden).any() or np.isinf(hidden).any():
print(f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
f"Requested IDs (sample): {node_ids[:5]}...{RESET}")
print(
f"{RED}!!! ERROR: NaN or Inf detected in final 'hidden' numpy array BEFORE sending! "
f"Requested IDs (sample): {node_ids[:5]}...{RESET}"
)
assert False
try:
hidden_contiguous_f32 = np.ascontiguousarray(hidden, dtype=np.float32)
hidden_contiguous_f32 = np.ascontiguousarray(
hidden, dtype=np.float32
)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist()
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
response_bytes = msgpack.packb(
response_payload, use_single_float=True
)
except Exception as pack_error:
print(f"Error packing MessagePack response: {pack_error}")
response_bytes = msgpack.packb([[], []])
print(f"Error packing MessagePack response: {pack_error}")
response_bytes = msgpack.packb([[], []])
socket.send(response_bytes)
ser_end = time.time()
@@ -606,8 +866,9 @@ def create_hnsw_embedding_server(
except Exception as e:
print(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
try:
socket.send(msgpack.packb([[], []]))
except:
pass
@@ -621,7 +882,7 @@ def create_hnsw_embedding_server(
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
@@ -634,17 +895,41 @@ def create_hnsw_embedding_server(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="HNSW Embedding service")
parser.add_argument("--zmq-port", type=int, default=5555, help="ZMQ port to run on")
parser.add_argument("--passages-file", type=str, help="JSON file containing passage ID to text mapping")
parser.add_argument("--embeddings-file", type=str, help="Pickle file containing pre-computed embeddings")
parser.add_argument(
"--passages-file",
type=str,
help="JSON file containing passage ID to text mapping",
)
parser.add_argument(
"--embeddings-file",
type=str,
help="Pickle file containing pre-computed embeddings",
)
parser.add_argument("--use-fp16", action="store_true", default=False)
parser.add_argument("--use-int8", action="store_true", default=False)
parser.add_argument("--use-cuda-graphs", action="store_true", default=False)
parser.add_argument("--max-batch-size", type=int, default=128, help="Maximum batch size before splitting")
parser.add_argument("--model-name", type=str, default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name")
parser.add_argument("--custom-max-length", type=int, default=None, help="Override model's default max sequence length")
parser.add_argument("--distance-metric", type=str, default="mips", help="Distance metric to use")
parser.add_argument(
"--max-batch-size",
type=int,
default=128,
help="Maximum batch size before splitting",
)
parser.add_argument(
"--model-name",
type=str,
default="sentence-transformers/all-mpnet-base-v2",
help="Embedding model name",
)
parser.add_argument(
"--custom-max-length",
type=int,
default=None,
help="Override model's default max sequence length",
)
parser.add_argument(
"--distance-metric", type=str, default="mips", help="Distance metric to use"
)
args = parser.parse_args()
# Create and start the HNSW embedding server
@@ -659,4 +944,4 @@ if __name__ == "__main__":
model_name=args.model_name,
custom_max_length_param=args.custom_max_length,
distance_metric=args.distance_metric,
)