fix(embedding-server): make the shutdown-capable ZMQ server threads create and bind their own REP sockets and poll with receive timeouts; fixes the undefined `socket` variable that caused a startup crash and CI hangs on Ubuntu 22.04
This commit is contained in:
@@ -222,121 +222,144 @@ def create_diskann_embedding_server(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
"""ZMQ server thread that respects shutdown signal."""
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
This creates its own REP socket, binds to zmq_port, and periodically
|
||||||
|
checks shutdown_event using recv timeouts to exit cleanly.
|
||||||
|
"""
|
||||||
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
logger.info("DiskANN ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
rep_socket = context.socket(zmq.REP)
|
||||||
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")
|
||||||
|
|
||||||
# Set receive timeout so we can check shutdown_event periodically
|
# Set receive timeout so we can check shutdown_event periodically
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
while not shutdown_event.is_set():
|
|
||||||
try:
|
|
||||||
e2e_start = time.time()
|
|
||||||
# REP socket receives single-part messages
|
|
||||||
message = socket.recv(zmq.NOBLOCK)
|
|
||||||
|
|
||||||
# Check for empty messages - REP socket requires response to every request
|
|
||||||
if not message:
|
|
||||||
logger.warning("Received empty message, sending empty response")
|
|
||||||
socket.send(b"")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Try protobuf first (same logic as original)
|
|
||||||
texts = []
|
|
||||||
is_text_request = False
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
try:
|
try:
|
||||||
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
e2e_start = time.time()
|
||||||
req_proto.ParseFromString(message)
|
# REP socket receives single-part messages
|
||||||
node_ids = list(req_proto.node_ids)
|
message = rep_socket.recv()
|
||||||
|
|
||||||
# Look up texts by node IDs
|
# Check for empty messages - REP socket requires response to every request
|
||||||
for nid in node_ids:
|
if not message:
|
||||||
try:
|
logger.warning("Received empty message, sending empty response")
|
||||||
txt = passages.get_text(nid)
|
rep_socket.send(b"")
|
||||||
if not txt:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
|
|
||||||
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
|
||||||
except Exception:
|
|
||||||
# Fallback to msgpack for text requests
|
|
||||||
try:
|
|
||||||
import msgpack
|
|
||||||
|
|
||||||
request = msgpack.unpackb(message)
|
|
||||||
if isinstance(request, list) and all(
|
|
||||||
isinstance(item, str) for item in request
|
|
||||||
):
|
|
||||||
texts = request
|
|
||||||
is_text_request = True
|
|
||||||
logger.info(f"ZMQ received msgpack text request for {len(texts)} texts")
|
|
||||||
else:
|
|
||||||
raise ValueError("Not a valid msgpack text request")
|
|
||||||
except Exception:
|
|
||||||
logger.error("Both protobuf and msgpack parsing failed!")
|
|
||||||
# Send error response
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
socket.send(resp_proto.SerializeToString())
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Process the request
|
# Try protobuf first (same logic as original)
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
texts = []
|
||||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
is_text_request = False
|
||||||
|
|
||||||
# Validation
|
try:
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
req_proto = embedding_pb2.NodeEmbeddingRequest()
|
||||||
logger.error("NaN or Inf detected in embeddings!")
|
req_proto.ParseFromString(message)
|
||||||
# Send error response
|
node_ids = list(req_proto.node_ids)
|
||||||
|
|
||||||
|
# Look up texts by node IDs
|
||||||
|
for nid in node_ids:
|
||||||
|
try:
|
||||||
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
|
||||||
|
logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
|
||||||
|
except Exception:
|
||||||
|
# Fallback to msgpack for text requests
|
||||||
|
try:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
request = msgpack.unpackb(message)
|
||||||
|
if isinstance(request, list) and all(
|
||||||
|
isinstance(item, str) for item in request
|
||||||
|
):
|
||||||
|
texts = request
|
||||||
|
is_text_request = True
|
||||||
|
logger.info(
|
||||||
|
f"ZMQ received msgpack text request for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not a valid msgpack text request")
|
||||||
|
except Exception:
|
||||||
|
logger.error("Both protobuf and msgpack parsing failed!")
|
||||||
|
# Send error response
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process the request
|
||||||
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error("NaN or Inf detected in embeddings!")
|
||||||
|
# Send error response
|
||||||
|
if is_text_request:
|
||||||
|
import msgpack
|
||||||
|
|
||||||
|
response_data = msgpack.packb([])
|
||||||
|
else:
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
response_data = resp_proto.SerializeToString()
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare response based on request type
|
||||||
if is_text_request:
|
if is_text_request:
|
||||||
|
# For direct text requests, return msgpack
|
||||||
import msgpack
|
import msgpack
|
||||||
|
|
||||||
response_data = msgpack.packb([])
|
response_data = msgpack.packb(embeddings.tolist())
|
||||||
else:
|
else:
|
||||||
|
# For protobuf requests, return protobuf
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
|
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
||||||
|
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
||||||
|
|
||||||
response_data = resp_proto.SerializeToString()
|
response_data = resp_proto.SerializeToString()
|
||||||
socket.send(response_data)
|
|
||||||
|
# Send response back to the client
|
||||||
|
rep_socket.send(response_data)
|
||||||
|
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
except zmq.Again:
|
||||||
|
# Timeout - check shutdown_event and continue
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
# Prepare response based on request type
|
if not shutdown_event.is_set():
|
||||||
if is_text_request:
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
# For direct text requests, return msgpack
|
try:
|
||||||
import msgpack
|
# Send error response for REP socket
|
||||||
|
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
||||||
response_data = msgpack.packb(embeddings.tolist())
|
rep_socket.send(resp_proto.SerializeToString())
|
||||||
else:
|
except Exception:
|
||||||
# For protobuf requests, return protobuf
|
pass
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
else:
|
||||||
hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
resp_proto.embeddings_data = hidden_contiguous.tobytes()
|
finally:
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[0])
|
try:
|
||||||
resp_proto.dimensions.append(hidden_contiguous.shape[1])
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
response_data = resp_proto.SerializeToString()
|
pass
|
||||||
|
try:
|
||||||
# Send response back to the client
|
context.term()
|
||||||
socket.send(response_data)
|
except Exception:
|
||||||
|
pass
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
|
|
||||||
except zmq.Again:
|
|
||||||
# Timeout - check shutdown_event and continue
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
|
||||||
try:
|
|
||||||
# Send error response for REP socket
|
|
||||||
resp_proto = embedding_pb2.NodeEmbeddingResponse()
|
|
||||||
socket.send(resp_proto.SerializeToString())
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
|
||||||
break
|
|
||||||
|
|
||||||
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
logger.info("DiskANN ZMQ server thread exiting gracefully")
|
||||||
|
|
||||||
|
|||||||
@@ -241,77 +241,108 @@ def create_hnsw_embedding_server(
|
|||||||
socket.send(msgpack.packb([[], []]))
|
socket.send(msgpack.packb([[], []]))
|
||||||
|
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
"""ZMQ server thread that respects shutdown signal."""
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
|
Creates its own REP socket bound to zmq_port and polls with timeouts
|
||||||
|
to allow graceful shutdown.
|
||||||
|
"""
|
||||||
logger.info("ZMQ server thread started with shutdown support")
|
logger.info("ZMQ server thread started with shutdown support")
|
||||||
|
|
||||||
# Set receive timeout so we can check shutdown_event periodically
|
context = zmq.Context()
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 1000) # 1 second timeout
|
rep_socket = context.socket(zmq.REP)
|
||||||
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
|
rep_socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
|
|
||||||
while not shutdown_event.is_set():
|
try:
|
||||||
try:
|
while not shutdown_event.is_set():
|
||||||
e2e_start = time.time()
|
try:
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
e2e_start = time.time()
|
||||||
request_bytes = socket.recv(zmq.NOBLOCK)
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
|
|
||||||
# Rest of the processing logic (same as original)
|
# Rest of the processing logic (same as original)
|
||||||
request = msgpack.unpackb(request_bytes)
|
request = msgpack.unpackb(request_bytes)
|
||||||
|
|
||||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
response_bytes = msgpack.packb([model_name])
|
response_bytes = msgpack.packb([model_name])
|
||||||
socket.send(response_bytes)
|
rep_socket.send(response_bytes)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
node_ids = request
|
# Handle direct text embedding request
|
||||||
logger.info(f"ZMQ received {len(node_ids)} node IDs")
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
|
):
|
||||||
|
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
continue
|
||||||
|
|
||||||
texts = []
|
node_ids = request if isinstance(request, list) else []
|
||||||
for nid in node_ids:
|
logger.info(f"ZMQ received {len(node_ids)} node IDs")
|
||||||
try:
|
|
||||||
txt = passages.get_text(nid)
|
|
||||||
if not txt:
|
|
||||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
|
||||||
texts.append(txt)
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
texts = []
|
||||||
logger.info(
|
for nid in node_ids:
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
try:
|
||||||
)
|
passage_data = passages.get_passage(str(nid))
|
||||||
|
txt = passage_data["text"]
|
||||||
|
if not txt:
|
||||||
|
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||||
|
texts.append(txt)
|
||||||
|
except KeyError:
|
||||||
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
logger.error(
|
logger.info(
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
raise AssertionError()
|
|
||||||
|
|
||||||
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
response_payload = [
|
logger.error(
|
||||||
list(hidden_contiguous_f32.shape),
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
hidden_contiguous_f32.flatten().tolist(),
|
)
|
||||||
]
|
raise AssertionError()
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
|
||||||
|
|
||||||
socket.send(response_bytes)
|
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
e2e_end = time.time()
|
response_payload = [
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
list(hidden_contiguous_f32.shape),
|
||||||
|
hidden_contiguous_f32.flatten().tolist(),
|
||||||
|
]
|
||||||
|
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
|
||||||
except zmq.Again:
|
rep_socket.send(response_bytes)
|
||||||
# Timeout - check shutdown_event and continue
|
e2e_end = time.time()
|
||||||
continue
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
except zmq.Again:
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
# Timeout - check shutdown_event and continue
|
||||||
try:
|
continue
|
||||||
socket.send(msgpack.packb([[], []]))
|
except Exception as e:
|
||||||
except Exception:
|
if not shutdown_event.is_set():
|
||||||
pass
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
else:
|
try:
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
rep_socket.send(msgpack.packb([[], []]))
|
||||||
break
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
rep_socket.close(0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
context.term()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
logger.info("ZMQ server thread exiting gracefully")
|
logger.info("ZMQ server thread exiting gracefully")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user