refactor: embedding server

This commit is contained in:
Andy Lee
2025-11-19 06:50:39 +00:00
parent 469dce0045
commit 29ef3c95dc

View File

@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
pass pass
return str(nid) return str(nid)
# (legacy ZMQ thread removed; using shutdown-capable server only)
def zmq_server_thread_with_shutdown(shutdown_event): def zmq_server_thread_with_shutdown(shutdown_event):
"""ZMQ server thread that respects shutdown signal. """ZMQ server thread that respects shutdown signal.
@@ -158,35 +156,42 @@ def create_hnsw_embedding_server(
rep_socket.bind(f"tcp://*:{zmq_port}") rep_socket.bind(f"tcp://*:{zmq_port}")
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}") logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
rep_socket.setsockopt(zmq.RCVTIMEO, 1000) rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
# Keep sends from blocking during shutdown; fail fast and drop on close
rep_socket.setsockopt(zmq.SNDTIMEO, 1000) rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
rep_socket.setsockopt(zmq.LINGER, 0) rep_socket.setsockopt(zmq.LINGER, 0)
# Track last request type/length for shape-correct fallbacks last_request_type = "unknown"
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
last_request_length = 0 last_request_length = 0
try: def _build_safe_fallback():
while not shutdown_event.is_set(): if last_request_type == "distance":
try: large_distance = 1e9
e2e_start = time.time() fallback_len = max(0, int(last_request_length))
logger.debug("🔍 Waiting for ZMQ message...") return [[large_distance] * fallback_len]
request_bytes = rep_socket.recv() if last_request_type == "embedding":
bsz = max(0, int(last_request_length))
dim = max(0, int(embedding_dim))
if dim > 0:
return [[bsz, dim], [0.0] * (bsz * dim)]
return [[0, 0], []]
if last_request_type == "text":
return []
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
# Rest of the processing logic (same as original) def _handle_request(request):
request = msgpack.unpackb(request_bytes) nonlocal last_request_type, last_request_length
if len(request) == 1 and request[0] == "__QUERY_MODEL__": # Model query
response_bytes = msgpack.packb([model_name]) if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
rep_socket.send(response_bytes) rep_socket.send(msgpack.packb([model_name]))
continue return
# Handle direct text embedding request # Direct text embedding
if ( if (
isinstance(request, list) isinstance(request, list)
and request and request
and all(isinstance(item, str) for item in request) and all(isinstance(item, str) for item in request)
): ):
e2e_start = time.time()
last_request_type = "text" last_request_type = "text"
last_request_length = len(request) last_request_length = len(request)
embeddings = compute_embeddings( embeddings = compute_embeddings(
@@ -197,18 +202,20 @@ def create_hnsw_embedding_server(
) )
rep_socket.send(msgpack.packb(embeddings.tolist())) rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s") logger.info(
continue f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
)
return
# Handle distance calculation request: [[ids], [query_vector]] # Distance calculation: [[ids], [query_vector]]
if ( if (
isinstance(request, list) isinstance(request, list)
and len(request) == 2 and len(request) == 2
and isinstance(request[0], list) and isinstance(request[0], list)
and isinstance(request[1], list) and isinstance(request[1], list)
): ):
e2e_start = time.time()
node_ids = request[0] node_ids = request[0]
# Handle nested [[ids]] shape defensively
if len(node_ids) == 1 and isinstance(node_ids[0], list): if len(node_ids) == 1 and isinstance(node_ids[0], list):
node_ids = node_ids[0] node_ids = node_ids[0]
query_vector = np.array(request[1], dtype=np.float32) query_vector = np.array(request[1], dtype=np.float32)
@@ -219,7 +226,6 @@ def create_hnsw_embedding_server(
logger.debug(f" Node IDs: {node_ids}") logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}") logger.debug(f" Query vector dim: {len(query_vector)}")
# Gather texts for found ids
texts: list[str] = [] texts: list[str] = []
found_indices: list[int] = [] found_indices: list[int] = []
for idx, nid in enumerate(node_ids): for idx, nid in enumerate(node_ids):
@@ -234,10 +240,9 @@ def create_hnsw_embedding_server(
logger.error(f"Empty text for passage ID {passage_id}") logger.error(f"Empty text for passage ID {passage_id}")
except KeyError: except KeyError:
logger.error(f"Passage ID {nid} not found") logger.error(f"Passage ID {nid} not found")
except Exception as e: except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {exc}")
# Prepare full-length response with large sentinel values
large_distance = 1e9 large_distance = 1e9
response_distances = [large_distance] * len(node_ids) response_distances = [large_distance] * len(node_ids)
@@ -256,21 +261,24 @@ def create_hnsw_embedding_server(
partial = np.sum( partial = np.sum(
np.square(embeddings - query_vector.reshape(1, -1)), axis=1 np.square(embeddings - query_vector.reshape(1, -1)), axis=1
) )
else: # mips or cosine else:
partial = -np.dot(embeddings, query_vector) partial = -np.dot(embeddings, query_vector)
for pos, dval in zip(found_indices, partial.flatten().tolist()): for pos, dval in zip(found_indices, partial.flatten().tolist()):
response_distances[pos] = float(dval) response_distances[pos] = float(dval)
except Exception as e: except Exception as exc:
logger.error(f"Distance computation error, using sentinels: {e}") logger.error(f"Distance computation error, using sentinels: {exc}")
# Send response in expected shape [[distances]] rep_socket.send(
rep_socket.send(msgpack.packb([response_distances], use_single_float=True)) msgpack.packb([response_distances], use_single_float=True)
)
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") logger.info(
continue f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
return
# Fallback: treat as embedding-by-id request # Embedding-by-id fallback
if ( if (
isinstance(request, list) isinstance(request, list)
and len(request) == 1 and len(request) == 1
@@ -281,11 +289,12 @@ def create_hnsw_embedding_server(
node_ids = request node_ids = request
else: else:
node_ids = [] node_ids = []
e2e_start = time.time()
last_request_type = "embedding" last_request_type = "embedding"
last_request_length = len(node_ids) last_request_length = len(node_ids)
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch") logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
# Preallocate zero-filled flat data for robustness
if embedding_dim <= 0: if embedding_dim <= 0:
dims = [0, 0] dims = [0, 0]
flat_data: list[float] = [] flat_data: list[float] = []
@@ -293,9 +302,8 @@ def create_hnsw_embedding_server(
dims = [len(node_ids), embedding_dim] dims = [len(node_ids), embedding_dim]
flat_data = [0.0] * (dims[0] * dims[1]) flat_data = [0.0] * (dims[0] * dims[1])
# Collect texts for found ids texts = []
texts: list[str] = [] found_indices = []
found_indices: list[int] = []
for idx, nid in enumerate(node_ids): for idx, nid in enumerate(node_ids):
try: try:
passage_id = _map_node_id(nid) passage_id = _map_node_id(nid)
@@ -308,8 +316,8 @@ def create_hnsw_embedding_server(
logger.error(f"Empty text for passage ID {passage_id}") logger.error(f"Empty text for passage ID {passage_id}")
except KeyError: except KeyError:
logger.error(f"Passage with ID {nid} not found") logger.error(f"Passage with ID {nid} not found")
except Exception as e: except Exception as exc:
logger.error(f"Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {exc}")
if texts: if texts:
try: try:
@@ -339,44 +347,54 @@ def create_hnsw_embedding_server(
flat_data[start:end] = flat[ flat_data[start:end] = flat[
j * embedding_dim : (j + 1) * embedding_dim j * embedding_dim : (j + 1) * embedding_dim
] ]
except Exception as e: except Exception as exc:
logger.error(f"Embedding computation error, returning zeros: {e}") logger.error(f"Embedding computation error, returning zeros: {exc}")
response_payload = [dims, flat_data] response_payload = [dims, flat_data]
response_bytes = msgpack.packb(response_payload, use_single_float=True) rep_socket.send(
msgpack.packb(response_payload, use_single_float=True)
rep_socket.send(response_bytes) )
e2e_end = time.time() e2e_end = time.time()
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s") logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
except zmq.Again:
# Timeout - check shutdown_event and continue
continue
except Exception as e:
if not shutdown_event.is_set():
logger.error(f"Error in ZMQ server loop: {e}")
# Shape-correct fallback
try: try:
if last_request_type == "distance": while not shutdown_event.is_set():
large_distance = 1e9 try:
fallback_len = max(0, int(last_request_length)) logger.debug("🔍 Waiting for ZMQ message...")
safe = [[large_distance] * fallback_len] request_bytes = rep_socket.recv()
elif last_request_type == "embedding": except zmq.Again:
bsz = max(0, int(last_request_length)) continue
dim = max(0, int(embedding_dim))
safe = ( try:
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []] request = msgpack.unpackb(request_bytes)
) except Exception as exc:
elif last_request_type == "text": if shutdown_event.is_set():
safe = [] # direct text embeddings expectation is a flat list
else:
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
rep_socket.send(msgpack.packb(safe, use_single_float=True))
except Exception:
pass
else:
logger.info("Shutdown in progress, ignoring ZMQ error") logger.info("Shutdown in progress, ignoring ZMQ error")
break break
logger.error(f"Error unpacking ZMQ message: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(
msgpack.packb(safe, use_single_float=True)
)
except Exception:
pass
continue
try:
_handle_request(request)
except Exception as exc:
if shutdown_event.is_set():
logger.info("Shutdown in progress, ignoring ZMQ error")
break
logger.error(f"Error in ZMQ server loop: {exc}")
try:
safe = _build_safe_fallback()
rep_socket.send(
msgpack.packb(safe, use_single_float=True)
)
except Exception:
pass
finally: finally:
try: try:
rep_socket.close(0) rep_socket.close(0)