From 66c6aad3e4cea9d9c3d7a2eb3f3cd659186181e2 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Wed, 19 Nov 2025 06:54:10 +0000 Subject: [PATCH] refactor: embedding server --- .../hnsw_embedding_server.py | 195 +++++++++--------- 1 file changed, 102 insertions(+), 93 deletions(-) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index dac48e5..ae58223 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -177,108 +177,93 @@ def create_hnsw_embedding_server( return [] return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] - def _handle_request(request): + def _handle_text_embedding(request: list[str]) -> None: nonlocal last_request_type, last_request_length - # Model query - if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__": - rep_socket.send(msgpack.packb([model_name])) - return + e2e_start = time.time() + last_request_type = "text" + last_request_length = len(request) + embeddings = compute_embeddings( + request, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) + rep_socket.send(msgpack.packb(embeddings.tolist())) + e2e_end = time.time() + logger.info( + f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s" + ) - # Direct text embedding - if ( - isinstance(request, list) - and request - and all(isinstance(item, str) for item in request) - ): - e2e_start = time.time() - last_request_type = "text" - last_request_length = len(request) - embeddings = compute_embeddings( - request, - model_name, - mode=embedding_mode, - provider_options=PROVIDER_OPTIONS, - ) - rep_socket.send(msgpack.packb(embeddings.tolist())) - e2e_end = time.time() - logger.info( - f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s" - ) - return + def _handle_distance_request(request: list[Any]) -> None: + nonlocal last_request_type, last_request_length - # Distance calculation: [[ids], [query_vector]] - if ( - isinstance(request, list) - and len(request) == 2 - and isinstance(request[0], list) - and isinstance(request[1], list) - ): - e2e_start = time.time() - node_ids = request[0] - if len(node_ids) == 1 and isinstance(node_ids[0], list): - node_ids = node_ids[0] - query_vector = np.array(request[1], dtype=np.float32) - last_request_type = "distance" - last_request_length = len(node_ids) + e2e_start = time.time() + node_ids = request[0] + if len(node_ids) == 1 and isinstance(node_ids[0], list): + node_ids = node_ids[0] + query_vector = np.array(request[1], dtype=np.float32) + last_request_type = "distance" + last_request_length = len(node_ids) - logger.debug("Distance calculation request received") - logger.debug(f" Node IDs: {node_ids}") - logger.debug(f" Query vector dim: {len(query_vector)}") + logger.debug("Distance calculation request received") + logger.debug(f" Node IDs: {node_ids}") + logger.debug(f" Query vector dim: {len(query_vector)}") - texts: list[str] = [] - found_indices: list[int] = [] - for idx, nid in enumerate(node_ids): - try: - passage_id = _map_node_id(nid) - passage_data = passages.get_passage(passage_id) - txt = passage_data.get("text", "") - if isinstance(txt, str) and len(txt) > 0: - texts.append(txt) - found_indices.append(idx) - else: - logger.error(f"Empty text for passage ID {passage_id}") - except KeyError: - logger.error(f"Passage ID {nid} not found") - except Exception as exc: - logger.error(f"Exception looking up passage ID {nid}: {exc}") + texts: list[str] = [] + found_indices: list[int] = [] + for idx, nid in enumerate(node_ids): + try: + passage_id = _map_node_id(nid) + passage_data = passages.get_passage(passage_id) + txt = passage_data.get("text", "") + if isinstance(txt, str) and len(txt) > 0: + texts.append(txt) + found_indices.append(idx) + else: + logger.error(f"Empty text for passage ID {passage_id}") + except KeyError: + logger.error(f"Passage ID {nid} not found") + except Exception as exc: + logger.error(f"Exception looking up passage ID {nid}: {exc}") - large_distance = 1e9 - response_distances = [large_distance] * len(node_ids) + large_distance = 1e9 + response_distances = [large_distance] * len(node_ids) - if texts: - try: - embeddings = compute_embeddings( - texts, - model_name, - mode=embedding_mode, - provider_options=PROVIDER_OPTIONS, + if texts: + try: + embeddings = compute_embeddings( + texts, + model_name, + mode=embedding_mode, + provider_options=PROVIDER_OPTIONS, + ) + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + ) + if distance_metric == "l2": + partial = np.sum( + np.square(embeddings - query_vector.reshape(1, -1)), axis=1 ) - logger.info( - f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" - ) - if distance_metric == "l2": - partial = np.sum( - np.square(embeddings - query_vector.reshape(1, -1)), axis=1 - ) - else: - partial = -np.dot(embeddings, query_vector) + else: + partial = -np.dot(embeddings, query_vector) - for pos, dval in zip(found_indices, partial.flatten().tolist()): - response_distances[pos] = float(dval) - except Exception as exc: - logger.error(f"Distance computation error, using sentinels: {exc}") + for pos, dval in zip(found_indices, partial.flatten().tolist()): + response_distances[pos] = float(dval) + except Exception as exc: + logger.error(f"Distance computation error, using sentinels: {exc}") - rep_socket.send( - msgpack.packb([response_distances], use_single_float=True) - ) - e2e_end = time.time() - logger.info( - f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" - ) - return + rep_socket.send( + msgpack.packb([response_distances], use_single_float=True) + ) + e2e_end = time.time() + logger.info( + f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" + ) + + def _handle_embedding_by_id(request: Any) -> None: + nonlocal last_request_type, last_request_length - # Embedding-by-id fallback if ( isinstance(request, list) and len(request) == 1 @@ -302,8 +287,8 @@ def create_hnsw_embedding_server( dims = [len(node_ids), embedding_dim] flat_data = [0.0] * (dims[0] * dims[1]) - texts = [] - found_indices = [] + texts: list[str] = [] + found_indices: list[int] = [] for idx, nid in enumerate(node_ids): try: passage_id = _map_node_id(nid) @@ -382,7 +367,31 @@ def create_hnsw_embedding_server( continue try: - _handle_request(request) + # Model query + if ( + isinstance(request, list) + and len(request) == 1 + and request[0] == "__QUERY_MODEL__" + ): + rep_socket.send(msgpack.packb([model_name])) + # Direct text embedding + elif ( + isinstance(request, list) + and request + and all(isinstance(item, str) for item in request) + ): + _handle_text_embedding(request) + # Distance calculation: [[ids], [query_vector]] + elif ( + isinstance(request, list) + and len(request) == 2 + and isinstance(request[0], list) + and isinstance(request[1], list) + ): + _handle_distance_request(request) + # Embedding-by-id fallback + else: + _handle_embedding_by_id(request) except Exception as exc: if shutdown_event.is_set(): logger.info("Shutdown in progress, ignoring ZMQ error")