refactor: embedding server
This commit is contained in:
@@ -177,108 +177,93 @@ def create_hnsw_embedding_server(
|
|||||||
return []
|
return []
|
||||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
|
||||||
def _handle_request(request):
|
def _handle_text_embedding(request: list[str]) -> None:
|
||||||
nonlocal last_request_type, last_request_length
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
# Model query
|
e2e_start = time.time()
|
||||||
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
last_request_type = "text"
|
||||||
rep_socket.send(msgpack.packb([model_name]))
|
last_request_length = len(request)
|
||||||
return
|
embeddings = compute_embeddings(
|
||||||
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
|
e2e_end = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
|
)
|
||||||
|
|
||||||
# Direct text embedding
|
def _handle_distance_request(request: list[Any]) -> None:
|
||||||
if (
|
nonlocal last_request_type, last_request_length
|
||||||
isinstance(request, list)
|
|
||||||
and request
|
|
||||||
and all(isinstance(item, str) for item in request)
|
|
||||||
):
|
|
||||||
e2e_start = time.time()
|
|
||||||
last_request_type = "text"
|
|
||||||
last_request_length = len(request)
|
|
||||||
embeddings = compute_embeddings(
|
|
||||||
request,
|
|
||||||
model_name,
|
|
||||||
mode=embedding_mode,
|
|
||||||
provider_options=PROVIDER_OPTIONS,
|
|
||||||
)
|
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
|
||||||
e2e_end = time.time()
|
|
||||||
logger.info(
|
|
||||||
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Distance calculation: [[ids], [query_vector]]
|
e2e_start = time.time()
|
||||||
if (
|
node_ids = request[0]
|
||||||
isinstance(request, list)
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
and len(request) == 2
|
node_ids = node_ids[0]
|
||||||
and isinstance(request[0], list)
|
query_vector = np.array(request[1], dtype=np.float32)
|
||||||
and isinstance(request[1], list)
|
last_request_type = "distance"
|
||||||
):
|
last_request_length = len(node_ids)
|
||||||
e2e_start = time.time()
|
|
||||||
node_ids = request[0]
|
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
|
||||||
node_ids = node_ids[0]
|
|
||||||
query_vector = np.array(request[1], dtype=np.float32)
|
|
||||||
last_request_type = "distance"
|
|
||||||
last_request_length = len(node_ids)
|
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
logger.debug(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
texts: list[str] = []
|
texts: list[str] = []
|
||||||
found_indices: list[int] = []
|
found_indices: list[int] = []
|
||||||
for idx, nid in enumerate(node_ids):
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_id = _map_node_id(nid)
|
passage_id = _map_node_id(nid)
|
||||||
passage_data = passages.get_passage(passage_id)
|
passage_data = passages.get_passage(passage_id)
|
||||||
txt = passage_data.get("text", "")
|
txt = passage_data.get("text", "")
|
||||||
if isinstance(txt, str) and len(txt) > 0:
|
if isinstance(txt, str) and len(txt) > 0:
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
found_indices.append(idx)
|
found_indices.append(idx)
|
||||||
else:
|
else:
|
||||||
logger.error(f"Empty text for passage ID {passage_id}")
|
logger.error(f"Empty text for passage ID {passage_id}")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Passage ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
logger.error(f"Exception looking up passage ID {nid}: {exc}")
|
||||||
|
|
||||||
large_distance = 1e9
|
large_distance = 1e9
|
||||||
response_distances = [large_distance] * len(node_ids)
|
response_distances = [large_distance] * len(node_ids)
|
||||||
|
|
||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
texts,
|
texts,
|
||||||
model_name,
|
model_name,
|
||||||
mode=embedding_mode,
|
mode=embedding_mode,
|
||||||
provider_options=PROVIDER_OPTIONS,
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
|
)
|
||||||
|
if distance_metric == "l2":
|
||||||
|
partial = np.sum(
|
||||||
|
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
||||||
)
|
)
|
||||||
logger.info(
|
else:
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
partial = -np.dot(embeddings, query_vector)
|
||||||
)
|
|
||||||
if distance_metric == "l2":
|
|
||||||
partial = np.sum(
|
|
||||||
np.square(embeddings - query_vector.reshape(1, -1)), axis=1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
partial = -np.dot(embeddings, query_vector)
|
|
||||||
|
|
||||||
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
for pos, dval in zip(found_indices, partial.flatten().tolist()):
|
||||||
response_distances[pos] = float(dval)
|
response_distances[pos] = float(dval)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(f"Distance computation error, using sentinels: {exc}")
|
logger.error(f"Distance computation error, using sentinels: {exc}")
|
||||||
|
|
||||||
rep_socket.send(
|
rep_socket.send(
|
||||||
msgpack.packb([response_distances], use_single_float=True)
|
msgpack.packb([response_distances], use_single_float=True)
|
||||||
)
|
)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(
|
logger.info(
|
||||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
def _handle_embedding_by_id(request: Any) -> None:
|
||||||
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
# Embedding-by-id fallback
|
|
||||||
if (
|
if (
|
||||||
isinstance(request, list)
|
isinstance(request, list)
|
||||||
and len(request) == 1
|
and len(request) == 1
|
||||||
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
|
|||||||
dims = [len(node_ids), embedding_dim]
|
dims = [len(node_ids), embedding_dim]
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
found_indices = []
|
found_indices: list[int] = []
|
||||||
for idx, nid in enumerate(node_ids):
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_id = _map_node_id(nid)
|
passage_id = _map_node_id(nid)
|
||||||
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_handle_request(request)
|
# Model query
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and request[0] == "__QUERY_MODEL__"
|
||||||
|
):
|
||||||
|
rep_socket.send(msgpack.packb([model_name]))
|
||||||
|
# Direct text embedding
|
||||||
|
elif (
|
||||||
|
isinstance(request, list)
|
||||||
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
|
):
|
||||||
|
_handle_text_embedding(request)
|
||||||
|
# Distance calculation: [[ids], [query_vector]]
|
||||||
|
elif (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 2
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
and isinstance(request[1], list)
|
||||||
|
):
|
||||||
|
_handle_distance_request(request)
|
||||||
|
# Embedding-by-id fallback
|
||||||
|
else:
|
||||||
|
_handle_embedding_by_id(request)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
|||||||
Reference in New Issue
Block a user