refactor: embedding server

This commit is contained in:
Andy Lee
2025-11-19 06:54:10 +00:00
parent 29ef3c95dc
commit 66c6aad3e4

View File

@@ -177,20 +177,9 @@ def create_hnsw_embedding_server(
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
def _handle_request(request):
def _handle_text_embedding(request: list[str]) -> None:
nonlocal last_request_type, last_request_length
# Model query
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
rep_socket.send(msgpack.packb([model_name]))
return
# Direct text embedding
if (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
e2e_start = time.time()
last_request_type = "text"
last_request_length = len(request)
@@ -205,15 +194,10 @@ def create_hnsw_embedding_server(
logger.info(
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
return
# Distance calculation: [[ids], [query_vector]]
if (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
e2e_start = time.time()
node_ids = request[0]
if len(node_ids) == 1 and isinstance(node_ids[0], list):
@@ -276,9 +260,10 @@ def create_hnsw_embedding_server(
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
return
# Embedding-by-id fallback
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if (
isinstance(request, list)
and len(request) == 1
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1])
texts = []
found_indices = []
texts: list[str] = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids):
try:
passage_id = _map_node_id(nid)
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
continue
try:
_handle_request(request)
# Model query
if (
isinstance(request, list)
and len(request) == 1
and request[0] == "__QUERY_MODEL__"
):
rep_socket.send(msgpack.packb([model_name]))
# Direct text embedding
elif (
isinstance(request, list)
and request
and all(isinstance(item, str) for item in request)
):
_handle_text_embedding(request)
# Distance calculation: [[ids], [query_vector]]
elif (
isinstance(request, list)
and len(request) == 2
and isinstance(request[0], list)
and isinstance(request[1], list)
):
_handle_distance_request(request)
# Embedding-by-id fallback
else:
_handle_embedding_by_id(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")