From a7ad0bc3d63932b2f0d4884a21525467a7139620 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Wed, 13 Aug 2025 16:06:39 -0700 Subject: [PATCH] refactor(hnsw-server): remove duplicate legacy ZMQ thread; keep single shutdown-capable server implementation to reduce surface and avoid hangs --- .../hnsw_embedding_server.py | 152 +----------------- 1 file changed, 1 insertion(+), 151 deletions(-) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index fb4f011..6690c51 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -93,157 +93,7 @@ def create_hnsw_embedding_server( f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" ) - def zmq_server_thread(): - """ZMQ server thread""" - context = zmq.Context() - socket = context.socket(zmq.REP) - socket.bind(f"tcp://*:{zmq_port}") - logger.info(f"HNSW ZMQ server listening on port {zmq_port}") - - socket.setsockopt(zmq.RCVTIMEO, 300000) - socket.setsockopt(zmq.SNDTIMEO, 300000) - - while True: - try: - message_bytes = socket.recv() - logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes") - - e2e_start = time.time() - request_payload = msgpack.unpackb(message_bytes) - - # Handle direct text embedding request - if isinstance(request_payload, list) and len(request_payload) > 0: - # Check if this is a direct text request (list of strings) - if all(isinstance(item, str) for item in request_payload): - logger.info( - f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode" - ) - - # Use unified embedding computation (now with model caching) - embeddings = compute_embeddings( - request_payload, model_name, mode=embedding_mode - ) - - response = embeddings.tolist() - socket.send(msgpack.packb(response)) - e2e_end = time.time() - logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") - continue - - # Handle distance calculation requests - if ( - isinstance(request_payload, list) - and len(request_payload) == 2 - and isinstance(request_payload[0], list) - and isinstance(request_payload[1], list) - ): - node_ids = request_payload[0] - query_vector = np.array(request_payload[1], dtype=np.float32) - - logger.debug("Distance calculation request received") - logger.debug(f" Node IDs: {node_ids}") - logger.debug(f" Query vector dim: {len(query_vector)}") - - # Get embeddings for node IDs - texts = [] - for nid in node_ids: - try: - passage_data = passages.get_passage(str(nid)) - txt = passage_data["text"] - texts.append(txt) - except KeyError: - logger.error(f"Passage ID {nid} not found") - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") - except Exception as e: - logger.error(f"Exception looking up passage ID {nid}: {e}") - raise - - # Process embeddings - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) - - # Calculate distances - if distance_metric == "l2": - distances = np.sum( - np.square(embeddings - query_vector.reshape(1, -1)), axis=1 - ) - else: # mips or cosine - distances = -np.dot(embeddings, query_vector) - - response_payload = distances.flatten().tolist() - response_bytes = msgpack.packb([response_payload], use_single_float=True) - logger.debug(f"Sending distance response with {len(distances)} distances") - - socket.send(response_bytes) - e2e_end = time.time() - logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") - continue - - # Standard embedding request (passage ID lookup) - if ( - not isinstance(request_payload, list) - or len(request_payload) != 1 - or not isinstance(request_payload[0], list) - ): - logger.error( - f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}" - ) - socket.send(msgpack.packb([[], []])) - continue - - node_ids = request_payload[0] - logger.debug(f"Request for {len(node_ids)} node embeddings") - - # Look up texts by node IDs - texts = [] - for nid in node_ids: - try: - passage_data = passages.get_passage(str(nid)) - txt = passage_data["text"] - if not txt: - raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") - texts.append(txt) - except KeyError: - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") - except Exception as e: - logger.error(f"Exception looking up passage ID {nid}: {e}") - raise - - # Process embeddings - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) - - # Serialization and response - if np.isnan(embeddings).any() or np.isinf(embeddings).any(): - logger.error( - f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." - ) - raise AssertionError() - - hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) - response_payload = [ - list(hidden_contiguous_f32.shape), - hidden_contiguous_f32.flatten().tolist(), - ] - response_bytes = msgpack.packb(response_payload, use_single_float=True) - - socket.send(response_bytes) - e2e_end = time.time() - logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") - - except zmq.Again: - logger.debug("ZMQ socket timeout, continuing to listen") - continue - except Exception as e: - logger.error(f"Error in ZMQ server loop: {e}") - import traceback - - traceback.print_exc() - socket.send(msgpack.packb([[], []])) + # (legacy ZMQ thread removed; using shutdown-capable server only) def zmq_server_thread_with_shutdown(shutdown_event): """ZMQ server thread that respects shutdown signal.