diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index 013ae5a..9a475d2 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -95,6 +95,8 @@ def create_hnsw_embedding_server( passage_sources.append(source_copy) passages = PassageManager(passage_sources) + # Use index dimensions from metadata for shaping fallback responses + embedding_dim: int = int(meta.get("dimensions", 0)) logger.info( f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" ) @@ -109,6 +111,9 @@ def create_hnsw_embedding_server( socket.setsockopt(zmq.RCVTIMEO, 300000) socket.setsockopt(zmq.SNDTIMEO, 300000) + # Track last request type for safe fallback responses on exceptions + last_request_type = "unknown" # one of: 'text', 'distance', 'embedding', 'unknown' + last_request_length = 0 while True: try: message_bytes = socket.recv() @@ -121,6 +126,8 @@ def create_hnsw_embedding_server( if isinstance(request_payload, list) and len(request_payload) > 0: # Check if this is a direct text request (list of strings) if all(isinstance(item, str) for item in request_payload): + last_request_type = "text" + last_request_length = len(request_payload) logger.info( f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode" ) @@ -145,43 +152,66 @@ def create_hnsw_embedding_server( ): node_ids = request_payload[0] query_vector = np.array(request_payload[1], dtype=np.float32) + last_request_type = "distance" + last_request_length = len(node_ids) logger.debug("Distance calculation request received") logger.debug(f" Node IDs: {node_ids}") logger.debug(f" Query vector dim: {len(query_vector)}") - # Get embeddings for node IDs - texts = [] - for nid in node_ids: + # Get embeddings for node IDs, tolerate missing IDs + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): try: passage_data = passages.get_passage(str(nid)) - txt = passage_data["text"] - texts.append(txt) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {nid}") except KeyError: logger.error(f"Passage ID {nid} not found") - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") except Exception as e: logger.error(f"Exception looking up passage ID {nid}: {e}") - raise - # Process embeddings - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + # Prepare full-length response distances with safe fallbacks + large_distance = 1e9 + response_distances = [large_distance] * len(node_ids) + + if texts: + try: + # Process embeddings only for found indices + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + + # Calculate distances for found embeddings only + if distance_metric == "l2": + partial_distances = np.sum( + np.square(embeddings - query_vector.reshape(1, -1)), axis=1 + ) + else: # mips or cosine + partial_distances = -np.dot(embeddings, query_vector) + + # Place computed distances back into the full response array + for pos, dval in zip( + found_indices, partial_distances.flatten().tolist() + ): + response_distances[pos] = float(dval) + except Exception as e: + logger.error( + f"Distance computation error, falling back to large distances: {e}" + ) + + # Always reply with exactly len(node_ids) distances + response_bytes = msgpack.packb([response_distances], use_single_float=True) + logger.debug( + f"Sending distance response with {len(response_distances)} distances (found={len(found_indices)})" ) - # Calculate distances - if distance_metric == "l2": - distances = np.sum( - np.square(embeddings - query_vector.reshape(1, -1)), axis=1 - ) - else: # mips or cosine - distances = -np.dot(embeddings, query_vector) - - response_payload = distances.flatten().tolist() - response_bytes = msgpack.packb([response_payload], use_single_float=True) - logger.debug(f"Sending distance response with {len(distances)} distances") - socket.send(response_bytes) e2e_end = time.time() logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") @@ -201,40 +231,61 @@ def create_hnsw_embedding_server( node_ids = request_payload[0] logger.debug(f"Request for {len(node_ids)} node embeddings") + last_request_type = "embedding" + last_request_length = len(node_ids) - # Look up texts by node IDs - texts = [] - for nid in node_ids: + # Allocate output buffer (B, D) and fill with zeros for robustness + if embedding_dim <= 0: + logger.error("Embedding dimension unknown; cannot serve embedding request") + dims = [0, 0] + data = [] + else: + dims = [len(node_ids), embedding_dim] + data = [0.0] * (dims[0] * dims[1]) + + # Look up texts by node IDs; compute embeddings where available + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): try: passage_data = passages.get_passage(str(nid)) - txt = passage_data["text"] - if not txt: - raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") - texts.append(txt) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {nid}") except KeyError: - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") + logger.error(f"Passage with ID {nid} not found") except Exception as e: logger.error(f"Exception looking up passage ID {nid}: {e}") - raise - # Process embeddings - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) + if texts: + try: + # Process embeddings for found texts only + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) - # Serialization and response - if np.isnan(embeddings).any() or np.isinf(embeddings).any(): - logger.error( - f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." - ) - raise AssertionError() + if np.isnan(embeddings).any() or np.isinf(embeddings).any(): + logger.error( + f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." + ) + dims = [0, embedding_dim] + data = [] + else: + # Copy computed embeddings into the correct positions + emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) + flat = emb_f32.flatten().tolist() + for j, pos in enumerate(found_indices): + start = pos * embedding_dim + end = start + embedding_dim + data[start:end] = flat[j * embedding_dim : (j + 1) * embedding_dim] + except Exception as e: + logger.error(f"Embedding computation error, returning zeros: {e}") - hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) - response_payload = [ - list(hidden_contiguous_f32.shape), - hidden_contiguous_f32.flatten().tolist(), - ] + response_payload = [dims, data] response_bytes = msgpack.packb(response_payload, use_single_float=True) socket.send(response_bytes) @@ -249,7 +300,22 @@ def create_hnsw_embedding_server( import traceback traceback.print_exc() - socket.send(msgpack.packb([[], []])) + # Fallback to a safe, minimal-structure response to avoid client crashes + if last_request_type == "distance": + # Return a vector of large distances with the expected length + fallback_len = max(0, int(last_request_length)) + large_distance = 1e9 + safe_response = [[large_distance] * fallback_len] + elif last_request_type == "embedding": + # Return an empty embedding block with known dimension if available + if embedding_dim > 0: + safe_response = [[0, embedding_dim], []] + else: + safe_response = [[0, 0], []] + else: + # Unknown request type: default to empty embedding structure + safe_response = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] + socket.send(msgpack.packb(safe_response, use_single_float=True)) zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread.start()