fix(hnsw-server): robust ZMQ responses to prevent size mismatch and segfault in CI

This commit is contained in:
Andy Lee
2025-08-13 14:53:46 -07:00
parent c994635af6
commit 909d3cc6a8

View File

@@ -95,6 +95,8 @@ def create_hnsw_embedding_server(
passage_sources.append(source_copy)
passages = PassageManager(passage_sources)
# Use index dimensions from metadata for shaping fallback responses
embedding_dim: int = int(meta.get("dimensions", 0))
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
@@ -109,6 +111,9 @@ def create_hnsw_embedding_server(
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
# Track last request type for safe fallback responses on exceptions
last_request_type = "unknown" # one of: 'text', 'distance', 'embedding', 'unknown'
last_request_length = 0
while True:
try:
message_bytes = socket.recv()
@@ -121,6 +126,8 @@ def create_hnsw_embedding_server(
if isinstance(request_payload, list) and len(request_payload) > 0:
# Check if this is a direct text request (list of strings)
if all(isinstance(item, str) for item in request_payload):
last_request_type = "text"
last_request_length = len(request_payload)
logger.info(
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
)
@@ -145,43 +152,66 @@ def create_hnsw_embedding_server(
):
node_ids = request_payload[0]
query_vector = np.array(request_payload[1], dtype=np.float32)
last_request_type = "distance"
last_request_length = len(node_ids)
logger.debug("Distance calculation request received")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
for nid in node_ids:
# Get embeddings for node IDs, tolerate missing IDs
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
texts.append(txt)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
# Prepare full-length response distances with safe fallbacks
large_distance = 1e9
response_distances = [large_distance] * len(node_ids)
if texts:
try:
# Process embeddings only for found indices
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Calculate distances for found embeddings only
if distance_metric == "l2":
partial_distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
partial_distances = -np.dot(embeddings, query_vector)
# Place computed distances back into the full response array
for pos, dval in zip(
found_indices, partial_distances.flatten().tolist()
):
response_distances[pos] = float(dval)
except Exception as e:
logger.error(
f"Distance computation error, falling back to large distances: {e}"
)
# Always reply with exactly len(node_ids) distances
response_bytes = msgpack.packb([response_distances], use_single_float=True)
logger.debug(
f"Sending distance response with {len(response_distances)} distances (found={len(found_indices)})"
)
# Calculate distances
if distance_metric == "l2":
distances = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
)
else: # mips or cosine
distances = -np.dot(embeddings, query_vector)
response_payload = distances.flatten().tolist()
response_bytes = msgpack.packb([response_payload], use_single_float=True)
logger.debug(f"Sending distance response with {len(distances)} distances")
socket.send(response_bytes)
e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
@@ -201,40 +231,61 @@ def create_hnsw_embedding_server(
node_ids = request_payload[0]
logger.debug(f"Request for {len(node_ids)} node embeddings")
last_request_type = "embedding"
last_request_length = len(node_ids)
# Look up texts by node IDs
texts = []
for nid in node_ids:
# Allocate output buffer (B, D) and fill with zeros for robustness
if embedding_dim <= 0:
logger.error("Embedding dimension unknown; cannot serve embedding request")
dims = [0, 0]
data = []
else:
dims = [len(node_ids), embedding_dim]
data = [0.0] * (dims[0] * dims[1])
# Look up texts by node IDs; compute embeddings where available
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
txt = passage_data.get("text", "")
if isinstance(txt, str) and len(txt) > 0:
texts.append(txt)
found_indices.append(idx)
else:
logger.error(f"Empty text for passage ID {nid}")
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.error(f"Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
if texts:
try:
# Process embeddings for found texts only
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
raise AssertionError()
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
dims = [0, embedding_dim]
data = []
else:
# Copy computed embeddings into the correct positions
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
flat = emb_f32.flatten().tolist()
for j, pos in enumerate(found_indices):
start = pos * embedding_dim
end = start + embedding_dim
data[start:end] = flat[j * embedding_dim : (j + 1) * embedding_dim]
except Exception as e:
logger.error(f"Embedding computation error, returning zeros: {e}")
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
response_payload = [
list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_payload = [dims, data]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
socket.send(response_bytes)
@@ -249,7 +300,22 @@ def create_hnsw_embedding_server(
import traceback
traceback.print_exc()
socket.send(msgpack.packb([[], []]))
# Fallback to a safe, minimal-structure response to avoid client crashes
if last_request_type == "distance":
# Return a vector of large distances with the expected length
fallback_len = max(0, int(last_request_length))
large_distance = 1e9
safe_response = [[large_distance] * fallback_len]
elif last_request_type == "embedding":
# Return an empty embedding block with known dimension if available
if embedding_dim > 0:
safe_response = [[0, embedding_dim], []]
else:
safe_response = [[0, 0], []]
else:
# Unknown request type: default to empty embedding structure
safe_response = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
socket.send(msgpack.packb(safe_response, use_single_float=True))
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()