Compare commits
1 Commits
feature/sk
...
clean-stat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
909d3cc6a8 |
@@ -95,6 +95,8 @@ def create_hnsw_embedding_server(
|
||||
passage_sources.append(source_copy)
|
||||
|
||||
passages = PassageManager(passage_sources)
|
||||
# Use index dimensions from metadata for shaping fallback responses
|
||||
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
@@ -109,6 +111,9 @@ def create_hnsw_embedding_server(
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
|
||||
# Track last request type for safe fallback responses on exceptions
|
||||
last_request_type = "unknown" # one of: 'text', 'distance', 'embedding', 'unknown'
|
||||
last_request_length = 0
|
||||
while True:
|
||||
try:
|
||||
message_bytes = socket.recv()
|
||||
@@ -121,6 +126,8 @@ def create_hnsw_embedding_server(
|
||||
if isinstance(request_payload, list) and len(request_payload) > 0:
|
||||
# Check if this is a direct text request (list of strings)
|
||||
if all(isinstance(item, str) for item in request_payload):
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request_payload)
|
||||
logger.info(
|
||||
f"Processing direct text embedding request for {len(request_payload)} texts in {embedding_mode} mode"
|
||||
)
|
||||
@@ -145,43 +152,66 @@ def create_hnsw_embedding_server(
|
||||
):
|
||||
node_ids = request_payload[0]
|
||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
# Get embeddings for node IDs
|
||||
texts = []
|
||||
for nid in node_ids:
|
||||
# Get embeddings for node IDs, tolerate missing IDs
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
texts.append(txt)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
# Prepare full-length response distances with safe fallbacks
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
|
||||
if texts:
|
||||
try:
|
||||
# Process embeddings only for found indices
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Calculate distances for found embeddings only
|
||||
if distance_metric == "l2":
|
||||
partial_distances = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else: # mips or cosine
|
||||
partial_distances = -np.dot(embeddings, query_vector)
|
||||
|
||||
# Place computed distances back into the full response array
|
||||
for pos, dval in zip(
|
||||
found_indices, partial_distances.flatten().tolist()
|
||||
):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Distance computation error, falling back to large distances: {e}"
|
||||
)
|
||||
|
||||
# Always reply with exactly len(node_ids) distances
|
||||
response_bytes = msgpack.packb([response_distances], use_single_float=True)
|
||||
logger.debug(
|
||||
f"Sending distance response with {len(response_distances)} distances (found={len(found_indices)})"
|
||||
)
|
||||
|
||||
# Calculate distances
|
||||
if distance_metric == "l2":
|
||||
distances = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else: # mips or cosine
|
||||
distances = -np.dot(embeddings, query_vector)
|
||||
|
||||
response_payload = distances.flatten().tolist()
|
||||
response_bytes = msgpack.packb([response_payload], use_single_float=True)
|
||||
logger.debug(f"Sending distance response with {len(distances)} distances")
|
||||
|
||||
socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
logger.info(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
@@ -201,40 +231,61 @@ def create_hnsw_embedding_server(
|
||||
|
||||
node_ids = request_payload[0]
|
||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||
last_request_type = "embedding"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
# Look up texts by node IDs
|
||||
texts = []
|
||||
for nid in node_ids:
|
||||
# Allocate output buffer (B, D) and fill with zeros for robustness
|
||||
if embedding_dim <= 0:
|
||||
logger.error("Embedding dimension unknown; cannot serve embedding request")
|
||||
dims = [0, 0]
|
||||
data = []
|
||||
else:
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
# Look up texts by node IDs; compute embeddings where available
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_data = passages.get_passage(str(nid))
|
||||
txt = passage_data["text"]
|
||||
if not txt:
|
||||
raise RuntimeError(f"FATAL: Empty text for passage ID {nid}")
|
||||
texts.append(txt)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {nid}")
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
logger.error(f"Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if texts:
|
||||
try:
|
||||
# Process embeddings for found texts only
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Serialization and response
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
raise AssertionError()
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
dims = [0, embedding_dim]
|
||||
data = []
|
||||
else:
|
||||
# Copy computed embeddings into the correct positions
|
||||
emb_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
flat = emb_f32.flatten().tolist()
|
||||
for j, pos in enumerate(found_indices):
|
||||
start = pos * embedding_dim
|
||||
end = start + embedding_dim
|
||||
data[start:end] = flat[j * embedding_dim : (j + 1) * embedding_dim]
|
||||
except Exception as e:
|
||||
logger.error(f"Embedding computation error, returning zeros: {e}")
|
||||
|
||||
hidden_contiguous_f32 = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
response_payload = [
|
||||
list(hidden_contiguous_f32.shape),
|
||||
hidden_contiguous_f32.flatten().tolist(),
|
||||
]
|
||||
response_payload = [dims, data]
|
||||
response_bytes = msgpack.packb(response_payload, use_single_float=True)
|
||||
|
||||
socket.send(response_bytes)
|
||||
@@ -249,7 +300,22 @@ def create_hnsw_embedding_server(
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
socket.send(msgpack.packb([[], []]))
|
||||
# Fallback to a safe, minimal-structure response to avoid client crashes
|
||||
if last_request_type == "distance":
|
||||
# Return a vector of large distances with the expected length
|
||||
fallback_len = max(0, int(last_request_length))
|
||||
large_distance = 1e9
|
||||
safe_response = [[large_distance] * fallback_len]
|
||||
elif last_request_type == "embedding":
|
||||
# Return an empty embedding block with known dimension if available
|
||||
if embedding_dim > 0:
|
||||
safe_response = [[0, embedding_dim], []]
|
||||
else:
|
||||
safe_response = [[0, 0], []]
|
||||
else:
|
||||
# Unknown request type: default to empty embedding structure
|
||||
safe_response = [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
socket.send(msgpack.packb(safe_response, use_single_float=True))
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
|
||||
Reference in New Issue
Block a user