fix(embedding-server): ensure shutdown-capable ZMQ server threads create and bind their own REP sockets and use receive timeouts (RCVTIMEO) to check the shutdown event periodically; fix the undefined `socket` reference that caused a startup crash and CI hangs on Ubuntu 22.04

This commit is contained in:
Andy Lee
2025-08-13 13:53:08 -07:00
parent b381278c3e
commit 4b714f3b44
2 changed files with 211 additions and 157 deletions

View File

@@ -222,121 +222,144 @@ def create_diskann_embedding_server(
raise raise
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.""" """ZMQ server thread that respects shutdown signal.
This creates its own REP socket, binds to zmq_port, and periodically
checks shutdown_event using recv timeouts to exit cleanly.
"""
logger.info("DiskANN ZMQ server thread started with shutdown support") logger.info("DiskANN ZMQ server thread started with shutdown support")
context = zmq.Context()
rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
# Set receive timeout so we can check shutdown_event periodically # Set receive timeout so we can check shutdown_event periodically
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
rep_socket.setsockopt(zmq.SNDTIMEO, 300000)
while not shutdown_event.is_set():
try:
e2e_start = time.time()
# REP socket receives single-part messages
message = socket.recv(zmq.NOBLOCK)
# Check for empty messages - REP socket requires response to every request
if not message:
logger.warning("Received empty message, sending empty response")
socket.send(b"")
continue
# Try protobuf first (same logic as original)
texts = []
is_text_request = False
try:
while not shutdown_event.is_set():
try: try:
req_proto = embedding_pb2.NodeEmbeddingRequest() e2e_start = time.time()
req_proto.ParseFromString(message) # REP socket receives single-part messages
node_ids = list(req_proto.node_ids) message = rep_socket.recv()
# Look up texts by node IDs # Check for empty messages - REP socket requires response to every request
for nid in node_ids: if not message:
try: logger.warning("Received empty message, sending empty response")
txt = passages.get_text(nid) rep_socket.send(b"")
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(f"ZMQ received msgpack text request for {len(texts)} texts")
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
socket.send(resp_proto.SerializeToString())
continue continue
# Process the request # Try protobuf first (same logic as original)
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) texts = []
logger.info(f"Computed embeddings shape: {embeddings.shape}") is_text_request = False
# Validation try:
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): req_proto = embedding_pb2.NodeEmbeddingRequest()
logger.error("NaN or Inf detected in embeddings!") req_proto.ParseFromString(message)
# Send error response node_ids = list(req_proto.node_ids)
# Look up texts by node IDs
for nid in node_ids:
try:
passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
except Exception:
# Fallback to msgpack for text requests
try:
import msgpack
request = msgpack.unpackb(message)
if isinstance(request, list) and all(
isinstance(item, str) for item in request
):
texts = request
is_text_request = True
logger.info(
f"ZMQ received msgpack text request for {len(texts)} texts"
)
else:
raise ValueError("Not a valid msgpack text request")
except Exception:
logger.error("Both protobuf and msgpack parsing failed!")
# Send error response
resp_proto = embedding_pb2.NodeEmbeddingResponse()
rep_socket.send(resp_proto.SerializeToString())
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
logger.error("NaN or Inf detected in embeddings!")
# Send error response
if is_text_request:
import msgpack
response_data = msgpack.packb([])
else:
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = resp_proto.SerializeToString()
rep_socket.send(response_data)
continue
# Prepare response based on request type
if is_text_request: if is_text_request:
# For direct text requests, return msgpack
import msgpack import msgpack
response_data = msgpack.packb([]) response_data = msgpack.packb(embeddings.tolist())
else: else:
# For protobuf requests, return protobuf
resp_proto = embedding_pb2.NodeEmbeddingResponse() resp_proto = embedding_pb2.NodeEmbeddingResponse()
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
resp_proto.embeddings_data = hidden_contiguous.tobytes()
resp_proto.dimensions.append(hidden_contiguous.shape[0])
resp_proto.dimensions.append(hidden_contiguous.shape[1])
response_data = resp_proto.SerializeToString() response_data = resp_proto.SerializeToString()
socket.send(response_data)
# Send response back to the client
rep_socket.send(response_data)
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue continue
except Exception as e:
# Prepare response based on request type if not shutdown_event.is_set():
if is_text_request: logger.error(f"Error in ZMQ server loop: {e}")
# For direct text requests, return msgpack try:
import msgpack # Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
response_data = msgpack.packb(embeddings.tolist()) rep_socket.send(resp_proto.SerializeToString())
else: except Exception:
# For protobuf requests, return protobuf pass
resp_proto = embedding_pb2.NodeEmbeddingResponse() else:
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32) logger.info("Shutdown in progress, ignoring ZMQ error")
break
resp_proto.embeddings_data = hidden_contiguous.tobytes() finally:
resp_proto.dimensions.append(hidden_contiguous.shape[0]) try:
resp_proto.dimensions.append(hidden_contiguous.shape[1]) rep_socket.close(0)
except Exception:
response_data = resp_proto.SerializeToString() pass
try:
# Send response back to the client context.term()
socket.send(response_data) except Exception:
pass
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
try:
# Send error response for REP socket
resp_proto = embedding_pb2.NodeEmbeddingResponse()
socket.send(resp_proto.SerializeToString())
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.info("DiskANN ZMQ server thread exiting gracefully") logger.info("DiskANN ZMQ server thread exiting gracefully")

View File

@@ -241,77 +241,108 @@ def create_hnsw_embedding_server(
socket.send(msgpack.packb([[], []])) socket.send(msgpack.packb([[], []]))
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal.""" """ZMQ server thread that respects shutdown signal.
Creates its own REP socket bound to zmq_port and polls with timeouts
to allow graceful shutdown.
"""
logger.info("ZMQ server thread started with shutdown support") logger.info("ZMQ server thread started with shutdown support")
# Set receive timeout so we can check shutdown_event periodically context = zmq.Context()
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout rep_socket = context.socket(zmq.REP)
rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
rep_socket.setsockopt(zmq.SNDTIMEO, 300000)
while not shutdown_event.is_set(): try:
try: while not shutdown_event.is_set():
e2e_start = time.time() try:
logger.debug("🔍 Waiting for ZMQ message...") e2e_start = time.time()
request_bytes = socket.recv(zmq.NOBLOCK) logger.debug("🔍 Waiting for ZMQ message...")
request_bytes = rep_socket.recv()
# Rest of the processing logic (same as original) # Rest of the processing logic (same as original)
request = msgpack.unpackb(request_bytes) request = msgpack.unpackb(request_bytes)
if len(request) == 1 and request[0] == "__QUERY_MODEL__": if len(request) == 1 and request[0] == "__QUERY_MODEL__":
response_bytes = msgpack.packb([model_name]) response_bytes = msgpack.packb([model_name])
socket.send(response_bytes) rep_socket.send(response_bytes)
continue continue
node_ids = request # Handle direct text embedding request
logger.info(f"ZMQ received {len(node_ids)} node IDs") if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
continue
texts = [] node_ids = request if isinstance(request, list) else []
for nid in node_ids: logger.info(f"ZMQ received {len(node_ids)} node IDs")
try:
txt = passages.get_text(nid)
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) texts = []
logger.info( for nid in node_ids:
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" try:
) passage_data = passages.get_passage(str(nid))
txt = passage_data["text"]
if not txt:
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
texts.append(txt)
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
logger.error( logger.info(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
raise AssertionError()
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) if np.isnan(embeddings).any() or np.isinf(embeddings).any():
response_payload = [ logger.error(
list(hidden_contiguous_f32.shape), f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
hidden_contiguous_f32.flatten().tolist(), )
] raise AssertionError()
response_bytes = msgpack.packb(response_payload, use_single_float=True)
socket.send(response_bytes) hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
e2e_end = time.time() response_payload = [
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") list(hidden_contiguous_f32.shape),
hidden_contiguous_f32.flatten().tolist(),
]
response_bytes = msgpack.packb(response_payload, use_single_float=True)
except zmq.Again: rep_socket.send(response_bytes)
# Timeout - check shutdown_event and continue e2e_end = time.time()
continue logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except Exception as e:
if not shutdown_event.is_set(): except zmq.Again:
logger.error(f"Error in ZMQ server loop: {e}") # Timeout - check shutdown_event and continue
try: continue
socket.send(msgpack.packb([[], []])) except Exception as e:
except Exception: if not shutdown_event.is_set():
pass logger.error(f"Error in ZMQ server loop: {e}")
else: try:
logger.info("Shutdown in progress, ignoring ZMQ error") rep_socket.send(msgpack.packb([[], []]))
break except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error")
break
finally:
try:
rep_socket.close(0)
except Exception:
pass
try:
context.term()
except Exception:
pass
logger.info("ZMQ server thread exiting gracefully") logger.info("ZMQ server thread exiting gracefully")