diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index 2d31994..523522f 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -84,6 +84,11 @@ def create_hnsw_embedding_server( # Let PassageManager handle path resolution uniformly passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file) + # Dimension from metadata for shaping responses + try: + embedding_dim: int = int(meta.get("dimensions", 0)) + except Exception: + embedding_dim = 0 logger.info( f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" ) @@ -255,6 +260,10 @@ def create_hnsw_embedding_server( rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 300000) + # Track last request type/length for shape-correct fallbacks + last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown' + last_request_length = 0 + try: while not shutdown_event.is_set(): try: @@ -276,45 +285,135 @@ def create_hnsw_embedding_server( and request and all(isinstance(item, str) for item in request) ): + last_request_type = "text" + last_request_length = len(request) embeddings = compute_embeddings(request, model_name, mode=embedding_mode) rep_socket.send(msgpack.packb(embeddings.tolist())) e2e_end = time.time() logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") continue - node_ids = request if isinstance(request, list) else [] - logger.info(f"ZMQ received {len(node_ids)} node IDs") + # Handle distance calculation request: [[ids], [query_vector]] + if ( + isinstance(request, list) + and len(request) == 2 + and isinstance(request[0], list) + and isinstance(request[1], list) + ): + node_ids = request[0] + query_vector = np.array(request[1], dtype=np.float32) + last_request_type = "distance" + last_request_length = len(node_ids) - texts = [] - for nid in node_ids: + logger.debug("Distance calculation request received") + logger.debug(f" Node IDs: {node_ids}") + logger.debug(f" Query vector dim: {len(query_vector)}") + + # Gather texts for found ids + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): + try: + passage_data = passages.get_passage(str(nid)) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {nid}") + except KeyError: + logger.error(f"Passage ID {nid} not found") + except Exception as e: + logger.error(f"Exception looking up passage ID {nid}: {e}") + + # Prepare full-length response with large sentinel values + large_distance = 1e9 + response_distances = [large_distance] * len(node_ids) + + if texts: + try: + embeddings = compute_embeddings( + texts, model_name, mode=embedding_mode + ) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + if distance_metric == "l2": + partial = np.sum( + np.square(embeddings - query_vector.reshape(1, -1)), axis=1 + ) + else: # mips or cosine + partial = -np.dot(embeddings, query_vector) + + for pos, dval in zip(found_indices, partial.flatten().tolist()): + response_distances[pos] = float(dval) + except Exception as e: + logger.error(f"Distance computation error, using sentinels: {e}") + + # Send response in expected shape [[distances]] + rep_socket.send(msgpack.packb([response_distances], use_single_float=True)) + e2e_end = time.time() + logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") + continue + + # Fallback: treat as embedding-by-id request [[ids]] + node_ids = request if isinstance(request, list) else [] + last_request_type = "embedding" + last_request_length = len(node_ids) + logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch") + + # Preallocate zero-filled flat data for robustness + if embedding_dim <= 0: + dims = [0, 0] + flat_data: list[float] = [] + else: + dims = [len(node_ids), embedding_dim] + flat_data = [0.0] * (dims[0] * dims[1]) + + # Collect texts for found ids + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): try: passage_data = passages.get_passage(str(nid)) - txt = passage_data["text"] - if not txt: - raise RuntimeError(f"FATAL: Empty text for passage ID {nid}") - texts.append(txt) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {nid}") except KeyError: - raise RuntimeError(f"FATAL: Passage with ID {nid} not found") + logger.error(f"Passage with ID {nid} not found") except Exception as e: logger.error(f"Exception looking up passage ID {nid}: {e}") - raise - embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) + if texts: + try: + embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) - if np.isnan(embeddings).any() or np.isinf(embeddings).any(): - logger.error( - f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." - ) - raise AssertionError() + if np.isnan(embeddings).any() or np.isinf(embeddings).any(): + logger.error( + f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." + ) + dims = [0, embedding_dim] + flat_data = [] + else: + emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) + flat = emb_f32.flatten().tolist() + for j, pos in enumerate(found_indices): + start = pos * embedding_dim + end = start + embedding_dim + if end <= len(flat_data): + flat_data[start:end] = flat[ + j * embedding_dim : (j + 1) * embedding_dim + ] + except Exception as e: + logger.error(f"Embedding computation error, returning zeros: {e}") - hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32) - response_payload = [ - list(hidden_contiguous_f32.shape), - hidden_contiguous_f32.flatten().tolist(), - ] + response_payload = [dims, flat_data] response_bytes = msgpack.packb(response_payload, use_single_float=True) rep_socket.send(response_bytes) @@ -327,8 +426,23 @@ def create_hnsw_embedding_server( except Exception as e: if not shutdown_event.is_set(): logger.error(f"Error in ZMQ server loop: {e}") + # Shape-correct fallback try: - rep_socket.send(msgpack.packb([[], []])) + if last_request_type == "distance": + large_distance = 1e9 + fallback_len = max(0, int(last_request_length)) + safe = [[large_distance] * fallback_len] + elif last_request_type == "embedding": + bsz = max(0, int(last_request_length)) + dim = max(0, int(embedding_dim)) + safe = ( + [[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []] + ) + elif last_request_type == "text": + safe = [] # direct text embeddings expectation is a flat list + else: + safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] + rep_socket.send(msgpack.packb(safe, use_single_float=True)) except Exception: pass else: