def zmq_server_thread_with_shutdown(shutdown_event):
    """Serve DiskANN embedding requests on a ZMQ REP socket until shutdown.

    This is the closure defined inside ``create_diskann_embedding_server``.
    Names it does not define itself (``zmq_port``, ``passages``,
    ``model_name``, ``embedding_mode``, ``compute_embeddings``,
    ``embedding_pb2``, ``logger``) are resolved from the surrounding scope.

    The thread creates its own ``zmq.Context`` and REP socket, binds to
    ``tcp://*:{zmq_port}`` and loops until ``shutdown_event`` is set. The
    receive timeout (1000 ms) exists only so the loop can re-check the
    event periodically.

    Request formats, tried in this order:
      1. protobuf ``NodeEmbeddingRequest`` carrying ``node_ids``; each ID is
         looked up with ``passages.get_passage(str(nid))["text"]``.
      2. msgpack-encoded list of ``str`` (direct text request).

    Responses:
      * protobuf request -> ``NodeEmbeddingResponse`` with float32 bytes in
        ``embeddings_data`` and ``dimensions == [rows, cols]``.
      * text request     -> msgpack list of embedding rows.
      * any failure      -> an empty response of the matching kind, because
        a REP socket must answer every request it has received.

    Args:
        shutdown_event: object exposing ``is_set()`` (presumably a
            ``threading.Event``); the loop exits once it returns True.
    """
    logger.info("DiskANN ZMQ server thread started with shutdown support")

    context = zmq.Context()
    rep_socket = None
    try:
        # Socket creation, bind and option setup live inside the try block so
        # that a failure (e.g. port already in use) still releases the socket
        # and the context in the finally clause before the error propagates.
        rep_socket = context.socket(zmq.REP)
        rep_socket.bind(f"tcp://*:{zmq_port}")
        logger.info(f"DiskANN ZMQ REP server listening on port {zmq_port}")

        # Set receive timeout so we can check shutdown_event periodically
        rep_socket.setsockopt(zmq.RCVTIMEO, 1000)  # 1 second timeout
        rep_socket.setsockopt(zmq.SNDTIMEO, 300000)

        while not shutdown_event.is_set():
            try:
                # REP socket receives single-part messages
                message = rep_socket.recv()
                # Start the clock only once a request has actually arrived;
                # starting it before recv() would add idle wait to the E2E time.
                e2e_start = time.time()

                # Check for empty messages - REP socket requires response to every request
                if not message:
                    logger.warning("Received empty message, sending empty response")
                    rep_socket.send(b"")
                    continue

                texts = []
                is_text_request = False

                # NOTE: the broad except below also catches passage-lookup
                # failures, so those fall through to the msgpack attempt too.
                try:
                    req_proto = embedding_pb2.NodeEmbeddingRequest()
                    req_proto.ParseFromString(message)
                    node_ids = list(req_proto.node_ids)

                    # Look up texts by node IDs
                    for nid in node_ids:
                        try:
                            passage_data = passages.get_passage(str(nid))
                            txt = passage_data["text"]
                            if not txt:
                                raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
                            texts.append(txt)
                        except KeyError as err:
                            raise RuntimeError(
                                f"FATAL: Passage with ID {nid} not found"
                            ) from err

                    logger.info(f"ZMQ received protobuf request for {len(node_ids)} node IDs")
                except Exception:
                    # Fallback to msgpack for text requests
                    try:
                        # Lazy import kept from the original (presumably so
                        # msgpack is only needed when text requests are used).
                        import msgpack

                        request = msgpack.unpackb(message)
                        if isinstance(request, list) and all(
                            isinstance(item, str) for item in request
                        ):
                            texts = request
                            is_text_request = True
                            logger.info(
                                f"ZMQ received msgpack text request for {len(texts)} texts"
                            )
                        else:
                            raise ValueError("Not a valid msgpack text request")
                    except Exception:
                        logger.error("Both protobuf and msgpack parsing failed!")
                        # Send error response
                        resp_proto = embedding_pb2.NodeEmbeddingResponse()
                        rep_socket.send(resp_proto.SerializeToString())
                        continue

                # Process the request
                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                logger.info(f"Computed embeddings shape: {embeddings.shape}")

                # Validation: never ship NaN/Inf vectors to the client
                if np.isnan(embeddings).any() or np.isinf(embeddings).any():
                    logger.error("NaN or Inf detected in embeddings!")
                    # Send an empty response of the same kind as the request
                    if is_text_request:
                        import msgpack

                        response_data = msgpack.packb([])
                    else:
                        resp_proto = embedding_pb2.NodeEmbeddingResponse()
                        response_data = resp_proto.SerializeToString()
                    rep_socket.send(response_data)
                    continue

                # Prepare response based on request type
                if is_text_request:
                    # For direct text requests, return msgpack
                    import msgpack

                    response_data = msgpack.packb(embeddings.tolist())
                else:
                    # For protobuf requests, return protobuf
                    resp_proto = embedding_pb2.NodeEmbeddingResponse()
                    hidden_contiguous = np.ascontiguousarray(embeddings, dtype=np.float32)

                    resp_proto.embeddings_data = hidden_contiguous.tobytes()
                    resp_proto.dimensions.append(hidden_contiguous.shape[0])
                    resp_proto.dimensions.append(hidden_contiguous.shape[1])

                    response_data = resp_proto.SerializeToString()

                # Send response back to the client
                rep_socket.send(response_data)

                e2e_end = time.time()
                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
                # Timeout - check shutdown_event and continue
                continue
            except Exception as e:
                if not shutdown_event.is_set():
                    logger.error(f"Error in ZMQ server loop: {e}")
                    try:
                        # Send error response for REP socket
                        resp_proto = embedding_pb2.NodeEmbeddingResponse()
                        rep_socket.send(resp_proto.SerializeToString())
                    except Exception:
                        # Best effort: the socket may not be in a sendable state.
                        pass
                else:
                    logger.info("Shutdown in progress, ignoring ZMQ error")
                    break
    finally:
        # Best-effort teardown; errors here must not mask the original failure.
        if rep_socket is not None:
            try:
                rep_socket.close(0)
            except Exception:
                pass
        try:
            context.term()
        except Exception:
            pass

    logger.info("DiskANN ZMQ server thread exiting gracefully")
def zmq_server_thread_with_shutdown(shutdown_event):
    """Serve HNSW embedding requests on a ZMQ REP socket until shutdown.

    This is the closure defined inside ``create_hnsw_embedding_server``.
    Names it does not define itself (``zmq_port``, ``passages``,
    ``model_name``, ``embedding_mode``, ``compute_embeddings``, ``msgpack``,
    ``logger``) are resolved from the surrounding scope.

    The thread creates its own ``zmq.Context`` and REP socket, binds to
    ``tcp://*:{zmq_port}`` and loops until ``shutdown_event`` is set, using a
    1000 ms receive timeout so the event is re-checked periodically.

    Every request is a msgpack payload, handled as follows:
      * ``["__QUERY_MODEL__"]``     -> reply ``[model_name]``.
      * non-empty list of ``str``   -> treated as raw texts; reply is the
        msgpack-encoded ``embeddings.tolist()``.
      * any other list              -> treated as node IDs; each is looked up
        with ``passages.get_passage(str(nid))["text"]`` and the reply is
        ``[shape, flat_float32_values]``.
      * any failure                 -> reply ``[[], []]``, because a REP
        socket must answer every request it has received.

    Args:
        shutdown_event: object exposing ``is_set()`` (presumably a
            ``threading.Event``); the loop exits once it returns True.
    """
    logger.info("ZMQ server thread started with shutdown support")

    context = zmq.Context()
    rep_socket = None
    try:
        # Socket creation, bind and option setup live inside the try block so
        # that a failure (e.g. port already in use) still releases the socket
        # and the context in the finally clause before the error propagates.
        rep_socket = context.socket(zmq.REP)
        rep_socket.bind(f"tcp://*:{zmq_port}")
        logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
        rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
        rep_socket.setsockopt(zmq.SNDTIMEO, 300000)

        while not shutdown_event.is_set():
            try:
                logger.debug("🔍 Waiting for ZMQ message...")
                request_bytes = rep_socket.recv()
                # Start the clock only once a request has actually arrived;
                # starting it before recv() would add idle wait to the E2E time.
                e2e_start = time.time()

                request = msgpack.unpackb(request_bytes)

                # Model-name handshake
                if len(request) == 1 and request[0] == "__QUERY_MODEL__":
                    response_bytes = msgpack.packb([model_name])
                    rep_socket.send(response_bytes)
                    continue

                # Handle direct text embedding request.
                # NOTE: any non-empty all-str list takes this branch, so node
                # IDs sent as strings would be embedded as text, not looked up.
                if (
                    isinstance(request, list)
                    and request
                    and all(isinstance(item, str) for item in request)
                ):
                    embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
                    rep_socket.send(msgpack.packb(embeddings.tolist()))
                    e2e_end = time.time()
                    logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
                    continue

                node_ids = request if isinstance(request, list) else []
                logger.info(f"ZMQ received {len(node_ids)} node IDs")

                texts = []
                for nid in node_ids:
                    try:
                        passage_data = passages.get_passage(str(nid))
                        txt = passage_data["text"]
                        if not txt:
                            raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
                        texts.append(txt)
                    except KeyError as err:
                        raise RuntimeError(f"FATAL: Passage with ID {nid} not found") from err
                    except Exception as e:
                        logger.error(f"Exception looking up passage ID {nid}: {e}")
                        raise

                embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
                logger.info(
                    f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
                )

                if np.isnan(embeddings).any() or np.isinf(embeddings).any():
                    logger.error(
                        f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
                    )
                    # Raise with a message so the handler below logs something
                    # meaningful; the client still receives [[], []].
                    raise RuntimeError("NaN or Inf detected in embeddings")

                hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
                response_payload = [
                    list(hidden_contiguous_f32.shape),
                    hidden_contiguous_f32.flatten().tolist(),
                ]
                response_bytes = msgpack.packb(response_payload, use_single_float=True)

                rep_socket.send(response_bytes)
                e2e_end = time.time()
                logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")

            except zmq.Again:
                # Timeout - check shutdown_event and continue
                continue
            except Exception as e:
                if not shutdown_event.is_set():
                    logger.error(f"Error in ZMQ server loop: {e}")
                    try:
                        rep_socket.send(msgpack.packb([[], []]))
                    except Exception:
                        # Best effort: the socket may not be in a sendable state.
                        pass
                else:
                    logger.info("Shutdown in progress, ignoring ZMQ error")
                    break
    finally:
        # Best-effort teardown; errors here must not mask the original failure.
        if rep_socket is not None:
            try:
                rep_socket.close(0)
            except Exception:
                pass
        try:
            context.term()
        except Exception:
            pass

    logger.info("ZMQ server thread exiting gracefully")