refactor: embedding server
This commit is contained in:
@@ -143,8 +143,6 @@ def create_hnsw_embedding_server(
|
|||||||
pass
|
pass
|
||||||
return str(nid)
|
return str(nid)
|
||||||
|
|
||||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
|
||||||
|
|
||||||
def zmq_server_thread_with_shutdown(shutdown_event):
|
def zmq_server_thread_with_shutdown(shutdown_event):
|
||||||
"""ZMQ server thread that respects shutdown signal.
|
"""ZMQ server thread that respects shutdown signal.
|
||||||
|
|
||||||
@@ -158,225 +156,245 @@ def create_hnsw_embedding_server(
|
|||||||
rep_socket.bind(f"tcp://*:{zmq_port}")
|
rep_socket.bind(f"tcp://*:{zmq_port}")
|
||||||
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ REP server listening on port {zmq_port}")
|
||||||
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
rep_socket.setsockopt(zmq.RCVTIMEO, 1000)
|
||||||
# Keep sends from blocking during shutdown; fail fast and drop on close
|
|
||||||
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
rep_socket.setsockopt(zmq.SNDTIMEO, 1000)
|
||||||
rep_socket.setsockopt(zmq.LINGER, 0)
|
rep_socket.setsockopt(zmq.LINGER, 0)
|
||||||
|
|
||||||
# Track last request type/length for shape-correct fallbacks
|
last_request_type = "unknown"
|
||||||
last_request_type = "unknown" # 'text' | 'distance' | 'embedding' | 'unknown'
|
|
||||||
last_request_length = 0
|
last_request_length = 0
|
||||||
|
|
||||||
try:
|
def _build_safe_fallback():
|
||||||
while not shutdown_event.is_set():
|
if last_request_type == "distance":
|
||||||
try:
|
large_distance = 1e9
|
||||||
e2e_start = time.time()
|
fallback_len = max(0, int(last_request_length))
|
||||||
logger.debug("🔍 Waiting for ZMQ message...")
|
return [[large_distance] * fallback_len]
|
||||||
request_bytes = rep_socket.recv()
|
if last_request_type == "embedding":
|
||||||
|
bsz = max(0, int(last_request_length))
|
||||||
|
dim = max(0, int(embedding_dim))
|
||||||
|
if dim > 0:
|
||||||
|
return [[bsz, dim], [0.0] * (bsz * dim)]
|
||||||
|
return [[0, 0], []]
|
||||||
|
if last_request_type == "text":
|
||||||
|
return []
|
||||||
|
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
|
||||||
# Rest of the processing logic (same as original)
|
def _handle_request(request):
|
||||||
request = msgpack.unpackb(request_bytes)
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
if len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
# Model query
|
||||||
response_bytes = msgpack.packb([model_name])
|
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||||
rep_socket.send(response_bytes)
|
rep_socket.send(msgpack.packb([model_name]))
|
||||||
continue
|
return
|
||||||
|
|
||||||
# Handle direct text embedding request
|
# Direct text embedding
|
||||||
if (
|
if (
|
||||||
isinstance(request, list)
|
isinstance(request, list)
|
||||||
and request
|
and request
|
||||||
and all(isinstance(item, str) for item in request)
|
and all(isinstance(item, str) for item in request)
|
||||||
):
|
):
|
||||||
last_request_type = "text"
|
e2e_start = time.time()
|
||||||
last_request_length = len(request)
|
last_request_type = "text"
|
||||||
|
last_request_length = len(request)
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Distance calculation: [[ids], [query_vector]]
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 2
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
and isinstance(request[1], list)
|
||||||
|
):
|
||||||
|
e2e_start = time.time()
|
||||||
|
node_ids = request[0]
|
||||||
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
|
node_ids = node_ids[0]
|
||||||
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
|
last_request_type = "distance"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
|
||||||
|
logger.debug("Distance calculation request received")
|
||||||
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
|
texts: list[str] = []
|
||||||
|
found_indices: list[int] = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_id = _map_node_id(nid)
|
||||||
|
passage_data = passages.get_passage(passage_id)
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage ID {nid} not found")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||||
|
|
||||||
|
large_distance = 1e9
|
||||||
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
request,
|
texts,
|
||||||
model_name,
|
model_name,
|
||||||
mode=embedding_mode,
|
mode=embedding_mode,
|
||||||
provider_options=PROVIDER_OPTIONS,
|
provider_options=PROVIDER_OPTIONS,
|
||||||
)
|
)
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
logger.info(
|
||||||
e2e_end = time.time()
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
)
|
||||||
continue
|
if distance_metric == "l2":
|
||||||
|
partial = np.sum(
|
||||||
# Handle distance calculation request: [[ids], [query_vector]]
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 2
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
and isinstance(request[1], list)
|
|
||||||
):
|
|
||||||
node_ids = request[0]
|
|
||||||
# Handle nested [[ids]] shape defensively
|
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
|
||||||
node_ids = node_ids[0]
|
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
|
||||||
last_request_type = "distance"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
|
||||||
|
|
||||||
# Gather texts for found ids
|
|
||||||
texts: list[str] = []
|
|
||||||
found_indices: list[int] = []
|
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
|
||||||
passage_id = _map_node_id(nid)
|
|
||||||
passage_data = passages.get_passage(passage_id)
|
|
||||||
txt = passage_data.get("text", "")
|
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"Passage ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
|
|
||||||
# Prepare full-length response with large sentinel values
|
|
||||||
large_distance = 1e9
|
|
||||||
response_distances = [large_distance] * len(node_ids)
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
try:
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
provider_options=PROVIDER_OPTIONS,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
|
||||||
)
|
|
||||||
if distance_metric == "l2":
|
|
||||||
partial = np.sum(
|
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
|
||||||
)
|
|
||||||
else: # mips or cosine
|
|
||||||
partial = -np.dot(embeddings, query_vector)
|
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
|
||||||
response_distances[pos] = float(dval)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Distance computation error, using sentinels: {e}")
|
|
||||||
|
|
||||||
# Send response in expected shape [[distances]]
|
|
||||||
rep_socket.send(msgpack.packb([response_distances], use_single_float=True))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fallback: treat as embedding-by-id request
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 1
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
):
|
|
||||||
node_ids = request[0]
|
|
||||||
elif isinstance(request, list):
|
|
||||||
node_ids = request
|
|
||||||
else:
|
|
||||||
node_ids = []
|
|
||||||
last_request_type = "embedding"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
|
||||||
|
|
||||||
# Preallocate zero-filled flat data for robustness
|
|
||||||
if embedding_dim <= 0:
|
|
||||||
dims = [0, 0]
|
|
||||||
flat_data: list[float] = []
|
|
||||||
else:
|
|
||||||
dims = [len(node_ids), embedding_dim]
|
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
|
||||||
|
|
||||||
# Collect texts for found ids
|
|
||||||
texts: list[str] = []
|
|
||||||
found_indices: list[int] = []
|
|
||||||
for idx, nid in enumerate(node_ids):
|
|
||||||
try:
|
|
||||||
passage_id = _map_node_id(nid)
|
|
||||||
passage_data = passages.get_passage(passage_id)
|
|
||||||
txt = passage_data.get("text", "")
|
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
|
||||||
texts.append(txt)
|
|
||||||
found_indices.append(idx)
|
|
||||||
else:
|
|
||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
|
||||||
except KeyError:
|
|
||||||
logger.error(f"Passage with ID {nid} not found")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
|
||||||
|
|
||||||
if texts:
|
|
||||||
try:
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
texts,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
provider_options=PROVIDER_OPTIONS,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
partial = -np.dot(embeddings, query_vector)
|
||||||
|
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
logger.error(
|
response_distances[pos] = float(dval)
|
||||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
except Exception as exc:
|
||||||
)
|
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||||
dims = [0, embedding_dim]
|
|
||||||
flat_data = []
|
|
||||||
else:
|
|
||||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
|
||||||
flat = emb_f32.flatten().tolist()
|
|
||||||
for j, pos in enumerate(found_indices):
|
|
||||||
start = pos * embedding_dim
|
|
||||||
end = start + embedding_dim
|
|
||||||
if end <= len(flat_data):
|
|
||||||
flat_data[start:end] = flat[
|
|
||||||
j * embedding_dim : (j + 1) * embedding_dim
|
|
||||||
]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
|
||||||
|
|
||||||
response_payload = [dims, flat_data]
|
rep_socket.send(
|
||||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
msgpack.packb([response_distances], use_single_float=True)
|
||||||
|
)
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
rep_socket.send(response_bytes)
|
# Embedding-by-id fallback
|
||||||
e2e_end = time.time()
|
if (
|
||||||
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
):
|
||||||
|
node_ids = request[0]
|
||||||
|
elif isinstance(request, list):
|
||||||
|
node_ids = request
|
||||||
|
else:
|
||||||
|
node_ids = []
|
||||||
|
|
||||||
|
e2e_start = time.time()
|
||||||
|
last_request_type = "embedding"
|
||||||
|
last_request_length = len(node_ids)
|
||||||
|
logger.info(f"ZMQ received {len(node_ids)} node IDs for embedding fetch")
|
||||||
|
|
||||||
|
if embedding_dim <= 0:
|
||||||
|
dims = [0, 0]
|
||||||
|
flat_data: list[float] = []
|
||||||
|
else:
|
||||||
|
dims = [len(node_ids), embedding_dim]
|
||||||
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
found_indices = []
|
||||||
|
for idx, nid in enumerate(node_ids):
|
||||||
|
try:
|
||||||
|
passage_id = _map_node_id(nid)
|
||||||
|
passage_data = passages.get_passage(passage_id)
|
||||||
|
txt = passage_data.get("text", "")
|
||||||
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
|
texts.append(txt)
|
||||||
|
found_indices.append(idx)
|
||||||
|
else:
|
||||||
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
|
except KeyError:
|
||||||
|
logger.error(f"Passage with ID {nid} not found")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||||
|
|
||||||
|
if texts:
|
||||||
|
try:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
logger.error(
|
||||||
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
|
)
|
||||||
|
dims = [0, embedding_dim]
|
||||||
|
flat_data = []
|
||||||
|
else:
|
||||||
|
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
flat = emb_f32.flatten().tolist()
|
||||||
|
for j, pos in enumerate(found_indices):
|
||||||
|
start = pos * embedding_dim
|
||||||
|
end = start + embedding_dim
|
||||||
|
if end <= len(flat_data):
|
||||||
|
flat_data[start:end] = flat[
|
||||||
|
j * embedding_dim : (j + 1) * embedding_dim
|
||||||
|
]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Embedding computation error, returning zeros: {exc}")
|
||||||
|
|
||||||
|
response_payload = [dims, flat_data]
|
||||||
|
rep_socket.send(
|
||||||
|
msgpack.packb(response_payload, use_single_float=True)
|
||||||
|
)
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(f"⏱️ ZMQ E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while not shutdown_event.is_set():
|
||||||
|
try:
|
||||||
|
logger.debug("🔍 Waiting for ZMQ message...")
|
||||||
|
request_bytes = rep_socket.recv()
|
||||||
except zmq.Again:
|
except zmq.Again:
|
||||||
# Timeout - check shutdown_event and continue
|
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
|
||||||
if not shutdown_event.is_set():
|
try:
|
||||||
logger.error(f"Error in ZMQ server loop: {e}")
|
request = msgpack.unpackb(request_bytes)
|
||||||
# Shape-correct fallback
|
except Exception as exc:
|
||||||
try:
|
if shutdown_event.is_set():
|
||||||
if last_request_type == "distance":
|
|
||||||
large_distance = 1e9
|
|
||||||
fallback_len = max(0, int(last_request_length))
|
|
||||||
safe = [[large_distance] * fallback_len]
|
|
||||||
elif last_request_type == "embedding":
|
|
||||||
bsz = max(0, int(last_request_length))
|
|
||||||
dim = max(0, int(embedding_dim))
|
|
||||||
safe = (
|
|
||||||
[[bsz, dim], [0.0] * (bsz * dim)] if dim > 0 else [[0, 0], []]
|
|
||||||
)
|
|
||||||
elif last_request_type == "text":
|
|
||||||
safe = [] # direct text embeddings expectation is a flat list
|
|
||||||
else:
|
|
||||||
safe = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
|
||||||
rep_socket.send(msgpack.packb(safe, use_single_float=True))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
break
|
break
|
||||||
|
logger.error(f"Error unpacking ZMQ message: {exc}")
|
||||||
|
try:
|
||||||
|
safe = _build_safe_fallback()
|
||||||
|
rep_socket.send(
|
||||||
|
msgpack.packb(safe, use_single_float=True)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
_handle_request(request)
|
||||||
|
except Exception as exc:
|
||||||
|
if shutdown_event.is_set():
|
||||||
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
break
|
||||||
|
logger.error(f"Error in ZMQ server loop: {exc}")
|
||||||
|
try:
|
||||||
|
safe = _build_safe_fallback()
|
||||||
|
rep_socket.send(
|
||||||
|
msgpack.packb(safe, use_single_float=True)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
rep_socket.close(0)
|
rep_socket.close(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user