fix: faster embed

This commit is contained in:
Andy Lee
2025-11-24 05:30:11 +00:00
parent 66c6aad3e4
commit 36c44b8806
4 changed files with 110 additions and 95 deletions

View File

@@ -191,9 +191,7 @@ def create_hnsw_embedding_server(
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
logger.info(f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_distance_request(request: list[Any]) -> None:
nonlocal last_request_type, last_request_length
@@ -253,22 +251,14 @@ def create_hnsw_embedding_server(
except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {exc}")
rep_socket.send(
msgpack.packb([response_distances], use_single_float=True)
)
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
e2e_end = time.time()
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
def _handle_embedding_by_id(request: Any) -> None:
nonlocal last_request_type, last_request_length
if (
isinstance(request, list)
and len(request) == 1
and isinstance(request[0], list)
):
if isinstance(request, list) and len(request) == 1 and isinstance(request[0], list):
node_ids = request[0]
elif isinstance(request, list):
node_ids = request
@@ -336,11 +326,9 @@ def create_hnsw_embedding_server(
logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data]
rep_socket.send(
msgpack.packb(response_payload, use_single_float=True)
)
rep_socket.send(msgpack.packb(response_payload, use_single_float=True))
e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
logger.info(f"⏱️ Fallback Embed by Id E2E time: {e2e_end - e2e_start:.6f}s")
try:
while not shutdown_event.is_set():
@@ -359,9 +347,7 @@ def create_hnsw_embedding_server(
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(
msgpack.packb(safe, use_single_float=True)
)
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
continue
@@ -399,9 +385,7 @@ def create_hnsw_embedding_server(
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(
msgpack.packb(safe, use_single_float=True)
)
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
finally: