refactor: embedding server
This commit is contained in:
@@ -177,108 +177,93 @@ def create_hnsw_embedding_server(
|
||||
return []
|
||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
|
||||
def _handle_request(request):
|
||||
def _handle_text_embedding(request: list[str]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
# Model query
|
||||
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||
rep_socket.send(msgpack.packb([model_name]))
|
||||
return
|
||||
e2e_start = time.time()
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
|
||||
# Direct text embedding
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
e2e_start = time.time()
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
embeddings = compute_embeddings(
|
||||
request,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
return
|
||||
def _handle_distance_request(request: list[Any]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
# Distance calculation: [[ids], [query_vector]]
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
e2e_start = time.time()
|
||||
node_ids = request[0]
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
node_ids = node_ids[0]
|
||||
query_vector = np.array(request[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
e2e_start = time.time()
|
||||
node_ids = request[0]
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
node_ids = node_ids[0]
|
||||
query_vector = np.array(request[1], dtype=np.float32)
|
||||
last_request_type = "distance"
|
||||
last_request_length = len(node_ids)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
logger.debug("Distance calculation request received")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as exc:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
passage_data = passages.get_passage(passage_id)
|
||||
txt = passage_data.get("text", "")
|
||||
if isinstance(txt, str) and len(txt) > 0:
|
||||
texts.append(txt)
|
||||
found_indices.append(idx)
|
||||
else:
|
||||
logger.error(f"Empty text for passage ID {passage_id}")
|
||||
except KeyError:
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
except Exception as exc:
|
||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
large_distance = 1e9
|
||||
response_distances = [large_distance] * len(node_ids)
|
||||
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
if texts:
|
||||
try:
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
provider_options=PROVIDER_OPTIONS,
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if distance_metric == "l2":
|
||||
partial = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
if distance_metric == "l2":
|
||||
partial = np.sum(
|
||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||
)
|
||||
else:
|
||||
partial = -np.dot(embeddings, query_vector)
|
||||
else:
|
||||
partial = -np.dot(embeddings, query_vector)
|
||||
|
||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as exc:
|
||||
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||
response_distances[pos] = float(dval)
|
||||
except Exception as exc:
|
||||
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||
|
||||
rep_socket.send(
|
||||
msgpack.packb([response_distances], use_single_float=True)
|
||||
)
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
return
|
||||
rep_socket.send(
|
||||
msgpack.packb([response_distances], use_single_float=True)
|
||||
)
|
||||
e2e_end = time.time()
|
||||
logger.info(
|
||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
|
||||
def _handle_embedding_by_id(request: Any) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
# Embedding-by-id fallback
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
flat_data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
texts = []
|
||||
found_indices = []
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
|
||||
continue
|
||||
|
||||
try:
|
||||
_handle_request(request)
|
||||
# Model query
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
and request[0] == "__QUERY_MODEL__"
|
||||
):
|
||||
rep_socket.send(msgpack.packb([model_name]))
|
||||
# Direct text embedding
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
_handle_text_embedding(request)
|
||||
# Distance calculation: [[ids], [query_vector]]
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
_handle_distance_request(request)
|
||||
# Embedding-by-id fallback
|
||||
else:
|
||||
_handle_embedding_by_id(request)
|
||||
except Exception as exc:
|
||||
if shutdown_event.is_set():
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
|
||||
Reference in New Issue
Block a user