refactor: embedding server
This commit is contained in:
@@ -177,20 +177,9 @@ def create_hnsw_embedding_server(
|
|||||||
return []
|
return []
|
||||||
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
return [[0, int(embedding_dim) if embedding_dim > 0 else 0], []]
|
||||||
|
|
||||||
def _handle_request(request):
|
def _handle_text_embedding(request: list[str]) -> None:
|
||||||
nonlocal last_request_type, last_request_length
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
# Model query
|
|
||||||
if isinstance(request, list) and len(request) == 1 and request[0] == "__QUERY_MODEL__":
|
|
||||||
rep_socket.send(msgpack.packb([model_name]))
|
|
||||||
return
|
|
||||||
|
|
||||||
# Direct text embedding
|
|
||||||
if (
|
|
||||||
isinstance(request, list)
|
|
||||||
and request
|
|
||||||
and all(isinstance(item, str) for item in request)
|
|
||||||
):
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
last_request_type = "text"
|
last_request_type = "text"
|
||||||
last_request_length = len(request)
|
last_request_length = len(request)
|
||||||
@@ -205,15 +194,10 @@ def create_hnsw_embedding_server(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
f"⏱️ Direct text embedding E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
# Distance calculation: [[ids], [query_vector]]
|
def _handle_distance_request(request: list[Any]) -> None:
|
||||||
if (
|
nonlocal last_request_type, last_request_length
|
||||||
isinstance(request, list)
|
|
||||||
and len(request) == 2
|
|
||||||
and isinstance(request[0], list)
|
|
||||||
and isinstance(request[1], list)
|
|
||||||
):
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
node_ids = request[0]
|
node_ids = request[0]
|
||||||
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
if len(node_ids) == 1 and isinstance(node_ids[0], list):
|
||||||
@@ -276,9 +260,10 @@ def create_hnsw_embedding_server(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
|
||||||
# Embedding-by-id fallback
|
def _handle_embedding_by_id(request: Any) -> None:
|
||||||
|
nonlocal last_request_type, last_request_length
|
||||||
|
|
||||||
if (
|
if (
|
||||||
isinstance(request, list)
|
isinstance(request, list)
|
||||||
and len(request) == 1
|
and len(request) == 1
|
||||||
@@ -302,8 +287,8 @@ def create_hnsw_embedding_server(
|
|||||||
dims = [len(node_ids), embedding_dim]
|
dims = [len(node_ids), embedding_dim]
|
||||||
flat_data = [0.0] * (dims[0] * dims[1])
|
flat_data = [0.0] * (dims[0] * dims[1])
|
||||||
|
|
||||||
texts = []
|
texts: list[str] = []
|
||||||
found_indices = []
|
found_indices: list[int] = []
|
||||||
for idx, nid in enumerate(node_ids):
|
for idx, nid in enumerate(node_ids):
|
||||||
try:
|
try:
|
||||||
passage_id = _map_node_id(nid)
|
passage_id = _map_node_id(nid)
|
||||||
@@ -382,7 +367,31 @@ def create_hnsw_embedding_server(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_handle_request(request)
|
# Model query
|
||||||
|
if (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 1
|
||||||
|
and request[0] == "__QUERY_MODEL__"
|
||||||
|
):
|
||||||
|
rep_socket.send(msgpack.packb([model_name]))
|
||||||
|
# Direct text embedding
|
||||||
|
elif (
|
||||||
|
isinstance(request, list)
|
||||||
|
and request
|
||||||
|
and all(isinstance(item, str) for item in request)
|
||||||
|
):
|
||||||
|
_handle_text_embedding(request)
|
||||||
|
# Distance calculation: [[ids], [query_vector]]
|
||||||
|
elif (
|
||||||
|
isinstance(request, list)
|
||||||
|
and len(request) == 2
|
||||||
|
and isinstance(request[0], list)
|
||||||
|
and isinstance(request[1], list)
|
||||||
|
):
|
||||||
|
_handle_distance_request(request)
|
||||||
|
# Embedding-by-id fallback
|
||||||
|
else:
|
||||||
|
_handle_embedding_by_id(request)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
logger.info("Shutdown in progress, ignoring ZMQ error")
|
logger.info("Shutdown in progress, ignoring ZMQ error")
|
||||||
|
|||||||
Reference in New Issue
Block a user