refactor: embedding server
This commit is contained in:
@@ -177,20 +177,9 @@ def create_hnsw_embedding_server(
|
||||
return []
|
||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||
|
||||
def _handle_request(request):
|
||||
def _handle_text_embedding(request: list[str]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
# Model query
|
||||
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
||||
rep_socket.send(msgpack.packb([model_name]))
|
||||
return
|
||||
|
||||
# Direct text embedding
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
e2e_start = time.time()
|
||||
last_request_type = "text"
|
||||
last_request_length = len(request)
|
||||
@@ -205,15 +194,10 @@ def create_hnsw_embedding_server(
|
||||
logger.info(
|
||||
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
return
|
||||
|
||||
# Distance calculation: [[ids], [query_vector]]
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
def _handle_distance_request(request: list[Any]) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
e2e_start = time.time()
|
||||
node_ids = request[0]
|
||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||
@@ -276,9 +260,10 @@ def create_hnsw_embedding_server(
|
||||
logger.info(
|
||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
return
|
||||
|
||||
# Embedding-by-id fallback
|
||||
def _handle_embedding_by_id(request: Any) -> None:
|
||||
nonlocal last_request_type, last_request_length
|
||||
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
|
||||
dims = [len(node_ids), embedding_dim]
|
||||
flat_data = [0.0] * (dims[0] * dims[1])
|
||||
|
||||
texts = []
|
||||
found_indices = []
|
||||
texts: list[str] = []
|
||||
found_indices: list[int] = []
|
||||
for idx, nid in enumerate(node_ids):
|
||||
try:
|
||||
passage_id = _map_node_id(nid)
|
||||
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
|
||||
continue
|
||||
|
||||
try:
|
||||
_handle_request(request)
|
||||
# Model query
|
||||
if (
|
||||
isinstance(request, list)
|
||||
and len(request) == 1
|
||||
and request[0] == "__QUERY_MODEL__"
|
||||
):
|
||||
rep_socket.send(msgpack.packb([model_name]))
|
||||
# Direct text embedding
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and request
|
||||
and all(isinstance(item, str) for item in request)
|
||||
):
|
||||
_handle_text_embedding(request)
|
||||
# Distance calculation: [[ids], [query_vector]]
|
||||
elif (
|
||||
isinstance(request, list)
|
||||
and len(request) == 2
|
||||
and isinstance(request[0], list)
|
||||
and isinstance(request[1], list)
|
||||
):
|
||||
_handle_distance_request(request)
|
||||
# Embedding-by-id fallback
|
||||
else:
|
||||
_handle_embedding_by_id(request)
|
||||
except Exception as exc:
|
||||
if shutdown_event.is_set():
|
||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||
|
||||
Reference in New Issue
Block a user