diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index 882acbf..dac48e5 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -143,8 +143,6 @@ def create_hnsw_embedding_server( pass return str(nid) - # (legacy ZMQ thread removed; using shutdown-capable server only) - def zmq_server_thread_with_shutdown(shutdown_event): """ZMQ server thread that respects shutdown signal. @@ -158,225 +156,245 @@ def create_hnsw_embedding_server( rep_socket.bind(f"tcp://*:{zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") rep_socket.setsockopt(zmq.RCVTIMEO, 1000) - # Keep sends from blocking during shutdown; fail fast and drop on close rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.LINGER, 0) - # Track last request type/length for shape-correct fallbacks - last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown' + last_request_type = "unknown" last_request_length = 0 - try: - while not shutdown_event.is_set(): - try: - e2e_start = time.time() - logger.debug("🔍 Waiting for ZMQ message...") - request_bytes = rep_socket.recv() + def _build_safe_fallback(): + if last_request_type == "distance": + large_distance = 1e9 + fallback_len = max(0, int(last_request_length)) + return [[large_distance] * fallback_len] + if last_request_type == "embedding": + bsz = max(0, int(last_request_length)) + dim = max(0, int(embedding_dim)) + if dim > 0: + return [[bsz, dim], [0.0] * (bsz * dim)] + return [[0, 0], []] + if last_request_type == "text": + return [] + return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] - # Rest of the processing logic (same as original) - request = msgpack.unpackb(request_bytes) + def _handle_request(request): + nonlocal last_request_type, last_request_length - if len(request) == 1 and request[0] == "__QUERY_MODEL__": - response_bytes = msgpack.packb([model_name]) - rep_socket.send(response_bytes) - continue + # Model query + if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__": + rep_socket.send(msgpack.packb([model_name])) + return - # Handle direct text embedding request - if ( - isinstance(request, list) - and request - and all(isinstance(item, str) for item in request) - ): - last_request_type = "text" - last_request_length = len(request) + # Direct text embedding + if ( + isinstance(request, list) + and request + and all(isinstance(item, str) for item in request) + ): + e2e_start = time.time() + last_request_type = "text" + last_request_length = len(request) + embeddings = compute_embeddings( + request, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) + rep_socket.send(msgpack.packb(embeddings.tolist())) + e2e_end = time.time() + logger.info( + f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s" + ) + return + + # Distance calculation: [[ids], [query_vector]] + if ( + isinstance(request, list) + and len(request) == 2 + and isinstance(request[0], list) + and isinstance(request[1], list) + ): + e2e_start = time.time() + node_ids = request[0] + if len(node_ids) == 1 and isinstance(node_ids[0], list): + node_ids = node_ids[0] + query_vector = np.array(request[1], dtype=np.float32) + last_request_type = "distance" + last_request_length = len(node_ids) + + logger.debug("Distance calculation request received") + logger.debug(f" Node IDs: {node_ids}") + logger.debug(f" Query vector dim: {len(query_vector)}") + + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): + try: + passage_id = _map_node_id(nid) + passage_data = passages.get_passage(passage_id) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {passage_id}") + except KeyError: + logger.error(f"Passage ID {nid} not found") + except Exception as exc: + logger.error(f"Exception looking up passage ID {nid}: {exc}") + + large_distance = 1e9 + response_distances = [large_distance] * len(node_ids) + + if texts: + try: embeddings = compute_embeddings( - request, + texts, model_name, mode=embedding_mode, provider_options=PROVIDER_OPTIONS, ) - rep_socket.send(msgpack.packb(embeddings.tolist())) - e2e_end = time.time() - logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") - continue - - # Handle distance calculation request: [[ids], [query_vector]] - if ( - isinstance(request, list) - and len(request) == 2 - and isinstance(request[0], list) - and isinstance(request[1], list) - ): - node_ids = request[0] - # Handle nested [[ids]] shape defensively - if len(node_ids) == 1 and isinstance(node_ids[0], list): - node_ids = node_ids[0] - query_vector = np.array(request[1], dtype=np.float32) - last_request_type = "distance" - last_request_length = len(node_ids) - - logger.debug("Distance calculation request received") - logger.debug(f" Node IDs: {node_ids}") - logger.debug(f" Query vector dim: {len(query_vector)}") - - # Gather texts for found ids - texts: list[str] = [] - found_indices: list[int] = [] - for idx, nid in enumerate(node_ids): - try: - passage_id = _map_node_id(nid) - passage_data = passages.get_passage(passage_id) - txt = passage_data.get("text", "") - if isinstance(txt, str) and len(txt) > 0: - texts.append(txt) - found_indices.append(idx) - else: - logger.error(f"Empty text for passage ID {passage_id}") - except KeyError: - logger.error(f"Passage ID {nid} not found") - except Exception as e: - logger.error(f"Exception looking up passage ID {nid}: {e}") - - # Prepare full-length response with large sentinel values - large_distance = 1e9 - response_distances = [large_distance] * len(node_ids) - - if texts: - try: - embeddings = compute_embeddings( - texts, - model_name, - mode=embedding_mode, - provider_options=PROVIDER_OPTIONS, - ) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) - if distance_metric == "l2": - partial = np.sum( - np.square(embeddings - query_vector.reshape(1, -1)), axis=1 - ) - else: # mips or cosine - partial = -np.dot(embeddings, query_vector) - - for pos, dval in zip(found_indices, partial.flatten().tolist()): - response_distances[pos] = float(dval) - except Exception as e: - logger.error(f"Distance computation error, using sentinels: {e}") - - # Send response in expected shape [[distances]] - rep_socket.send(msgpack.packb([response_distances], use_single_float=True)) - e2e_end = time.time() - logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") - continue - - # Fallback: treat as embedding-by-id request - if ( - isinstance(request, list) - and len(request) == 1 - and isinstance(request[0], list) - ): - node_ids = request[0] - elif isinstance(request, list): - node_ids = request - else: - node_ids = [] - last_request_type = "embedding" - last_request_length = len(node_ids) - logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch") - - # Preallocate zero-filled flat data for robustness - if embedding_dim <= 0: - dims = [0, 0] - flat_data: list[float] = [] - else: - dims = [len(node_ids), embedding_dim] - flat_data = [0.0] * (dims[0] * dims[1]) - - # Collect texts for found ids - texts: list[str] = [] - found_indices: list[int] = [] - for idx, nid in enumerate(node_ids): - try: - passage_id = _map_node_id(nid) - passage_data = passages.get_passage(passage_id) - txt = passage_data.get("text", "") - if isinstance(txt, str) and len(txt) > 0: - texts.append(txt) - found_indices.append(idx) - else: - logger.error(f"Empty text for passage ID {passage_id}") - except KeyError: - logger.error(f"Passage with ID {nid} not found") - except Exception as e: - logger.error(f"Exception looking up passage ID {nid}: {e}") - - if texts: - try: - embeddings = compute_embeddings( - texts, - model_name, - mode=embedding_mode, - provider_options=PROVIDER_OPTIONS, - ) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + if distance_metric == "l2": + partial = np.sum( + np.square(embeddings - query_vector.reshape(1, -1)), axis=1 ) + else: + partial = -np.dot(embeddings, query_vector) - if np.isnan(embeddings).any() or np.isinf(embeddings).any(): - logger.error( - f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." - ) - dims = [0, embedding_dim] - flat_data = [] - else: - emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) - flat = emb_f32.flatten().tolist() - for j, pos in enumerate(found_indices): - start = pos * embedding_dim - end = start + embedding_dim - if end <= len(flat_data): - flat_data[start:end] = flat[ - j * embedding_dim : (j + 1) * embedding_dim - ] - except Exception as e: - logger.error(f"Embedding computation error, returning zeros: {e}") + for pos, dval in zip(found_indices, partial.flatten().tolist()): + response_distances[pos] = float(dval) + except Exception as exc: + logger.error(f"Distance computation error, using sentinels: {exc}") - response_payload = [dims, flat_data] - response_bytes = msgpack.packb(response_payload, use_single_float=True) + rep_socket.send( + msgpack.packb([response_distances], use_single_float=True) + ) + e2e_end = time.time() + logger.info( + f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" + ) + return - rep_socket.send(response_bytes) - e2e_end = time.time() - logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") + # Embedding-by-id fallback + if ( + isinstance(request, list) + and len(request) == 1 + and isinstance(request[0], list) + ): + node_ids = request[0] + elif isinstance(request, list): + node_ids = request + else: + node_ids = [] + e2e_start = time.time() + last_request_type = "embedding" + last_request_length = len(node_ids) + logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch") + + if embedding_dim <= 0: + dims = [0, 0] + flat_data: list[float] = [] + else: + dims = [len(node_ids), embedding_dim] + flat_data = [0.0] * (dims[0] * dims[1]) + + texts = [] + found_indices = [] + for idx, nid in enumerate(node_ids): + try: + passage_id = _map_node_id(nid) + passage_data = passages.get_passage(passage_id) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {passage_id}") + except KeyError: + logger.error(f"Passage with ID {nid} not found") + except Exception as exc: + logger.error(f"Exception looking up passage ID {nid}: {exc}") + + if texts: + try: + embeddings = compute_embeddings( + texts, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + + if np.isnan(embeddings).any() or np.isinf(embeddings).any(): + logger.error( + f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." + ) + dims = [0, embedding_dim] + flat_data = [] + else: + emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) + flat = emb_f32.flatten().tolist() + for j, pos in enumerate(found_indices): + start = pos * embedding_dim + end = start + embedding_dim + if end <= len(flat_data): + flat_data[start:end] = flat[ + j * embedding_dim : (j + 1) * embedding_dim + ] + except Exception as exc: + logger.error(f"Embedding computation error, returning zeros: {exc}") + + response_payload = [dims, flat_data] + rep_socket.send( + msgpack.packb(response_payload, use_single_float=True) + ) + e2e_end = time.time() + logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") + + try: + while not shutdown_event.is_set(): + try: + logger.debug("🔍 Waiting for ZMQ message...") + request_bytes = rep_socket.recv() except zmq.Again: - # Timeout - check shutdown_event and continue continue - except Exception as e: - if not shutdown_event.is_set(): - logger.error(f"Error in ZMQ server loop: {e}") - # Shape-correct fallback - try: - if last_request_type == "distance": - large_distance = 1e9 - fallback_len = max(0, int(last_request_length)) - safe = [[large_distance] * fallback_len] - elif last_request_type == "embedding": - bsz = max(0, int(last_request_length)) - dim = max(0, int(embedding_dim)) - safe = ( - [[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []] - ) - elif last_request_type == "text": - safe = [] # direct text embeddings expectation is a flat list - else: - safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] - rep_socket.send(msgpack.packb(safe, use_single_float=True)) - except Exception: - pass - else: + + try: + request = msgpack.unpackb(request_bytes) + except Exception as exc: + if shutdown_event.is_set(): logger.info("Shutdown in progress, ignoring ZMQ error") break + logger.error(f"Error unpacking ZMQ message: {exc}") + try: + safe = _build_safe_fallback() + rep_socket.send( + msgpack.packb(safe, use_single_float=True) + ) + except Exception: + pass + continue + + try: + _handle_request(request) + except Exception as exc: + if shutdown_event.is_set(): + logger.info("Shutdown in progress, ignoring ZMQ error") + break + logger.error(f"Error in ZMQ server loop: {exc}") + try: + safe = _build_safe_fallback() + rep_socket.send( + msgpack.packb(safe, use_single_float=True) + ) + except Exception: + pass finally: try: rep_socket.close(0)