refactor: embedding server

This commit is contained in:
Andy Lee
2025-11-19 06:54:10 +00:00
parent 29ef3c95dc
commit 66c6aad3e4

View File

@@ -177,20 +177,9 @@ def create_hnsw_embedding_server(
return [] return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []] return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_request(request): def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length nonlocal last_request_type, last_request_length
# Model query
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
rep_socket.send(msgpack.packb([model_name]))
return
# Direct text embedding
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
e2e_start = time.time() e2e_start = time.time()
last_request_type = "text" last_request_type = "text"
last_request_length = len(request) last_request_length = len(request)
@@ -205,15 +194,10 @@ def create_hnsw_embedding_server(
logger.info( logger.info(
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s" f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
) )
return
# Distance calculation: [[ids], [query_vector]] def _handle_distance_request(request: list[Any]) -> None:
if ( nonlocal last_request_type, last_request_length
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
e2e_start = time.time() e2e_start = time.time()
node_ids = request[0] node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list): if len(node_ids) == 1 and isinstance(node_ids[0], list):
@@ -276,9 +260,10 @@ def create_hnsw_embedding_server(
logger.info( logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
) )
return
# Embedding-by-id fallback def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if ( if (
isinstance(request, list) isinstance(request, list)
and len(request) == 1 and len(request) == 1
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
dims = [len(node_ids), embedding_dim] dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1]) flat_data = [0.0] * (dims[0] * dims[1])
texts = [] texts: list[str] = []
found_indices = [] found_indices: list[int] = []
for idx, nid in enumerate(node_ids): for idx, nid in enumerate(node_ids):
try: try:
passage_id = _map_node_id(nid) passage_id = _map_node_id(nid)
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
continue continue
try: try:
_handle_request(request) # Model query
if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
_handle_text_embedding(request)
# Distance calculation: [[ids], [query_vector]]
elif (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
_handle_distance_request(request)
# Embedding-by-id fallback
else:
_handle_embedding_by_id(request)
except Exception as exc: except Exception as exc:
if shutdown_event.is_set(): if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")