refactor: logs
This commit is contained in:
@@ -11,13 +11,10 @@ import numpy as np
|
||||
import msgpack
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from typing import Optional
|
||||
import sys
|
||||
import logging
|
||||
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
# Set up logging based on environment variable
|
||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||
logging.basicConfig(
|
||||
@@ -38,8 +35,8 @@ def create_hnsw_embedding_server(
|
||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||
Simplified version using unified embedding computation module.
|
||||
"""
|
||||
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||
print(f"Using embedding mode: {embedding_mode}")
|
||||
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||
|
||||
# Add leann-core to path for unified embedding computation
|
||||
current_dir = Path(__file__).parent
|
||||
@@ -50,9 +47,9 @@ def create_hnsw_embedding_server(
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.api import PassageManager
|
||||
|
||||
print("Successfully imported unified embedding computation module")
|
||||
logger.info("Successfully imported unified embedding computation module")
|
||||
except ImportError as e:
|
||||
print(f"ERROR: Failed to import embedding computation module: {e}")
|
||||
logger.error(f"Failed to import embedding computation module: {e}")
|
||||
return
|
||||
finally:
|
||||
sys.path.pop(0)
|
||||
@@ -65,7 +62,7 @@ def create_hnsw_embedding_server(
|
||||
return s.connect_ex(("localhost", port)) == 0
|
||||
|
||||
if check_port(zmq_port):
|
||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
||||
logger.error(f"Port {zmq_port} is already in use")
|
||||
return
|
||||
|
||||
# Only support metadata file, fail fast for everything else
|
||||
@@ -77,7 +74,7 @@ def create_hnsw_embedding_server(
|
||||
meta = json.load(f)
|
||||
|
||||
passages = PassageManager(meta["passage_sources"])
|
||||
print(
|
||||
logger.info(
|
||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||
)
|
||||
|
||||
@@ -86,7 +83,7 @@ def create_hnsw_embedding_server(
|
||||
context = zmq.Context()
|
||||
socket = context.socket(zmq.REP)
|
||||
socket.bind(f"tcp://*:{zmq_port}")
|
||||
print(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||
|
||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||
@@ -94,7 +91,7 @@ def create_hnsw_embedding_server(
|
||||
while True:
|
||||
try:
|
||||
message_bytes = socket.recv()
|
||||
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||
|
||||
e2e_start = time.time()
|
||||
request_payload = msgpack.unpackb(message_bytes)
|
||||
@@ -131,8 +128,8 @@ def create_hnsw_embedding_server(
|
||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||
|
||||
logger.debug("Distance calculation request received")
|
||||
print(f" Node IDs: {node_ids}")
|
||||
print(f" Query vector dim: {len(query_vector)}")
|
||||
logger.debug(f" Node IDs: {node_ids}")
|
||||
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||
|
||||
# Get embeddings for node IDs
|
||||
texts = []
|
||||
@@ -142,20 +139,20 @@ def create_hnsw_embedding_server(
|
||||
txt = passage_data["text"]
|
||||
texts.append(txt)
|
||||
except KeyError:
|
||||
print(f"ERROR: Passage ID {nid} not found")
|
||||
logger.error(f"Passage ID {nid} not found")
|
||||
raise RuntimeError(
|
||||
f"FATAL: Passage with ID {nid} not found"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(
|
||||
texts, model_name, mode=embedding_mode
|
||||
)
|
||||
print(
|
||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Calculate distances
|
||||
@@ -170,11 +167,15 @@ def create_hnsw_embedding_server(
|
||||
response_bytes = msgpack.packb(
|
||||
[response_payload], use_single_float=True
|
||||
)
|
||||
print(f"Sending distance response with {len(distances)} distances")
|
||||
logger.debug(
|
||||
f"Sending distance response with {len(distances)} distances"
|
||||
)
|
||||
|
||||
socket.send(response_bytes)
|
||||
e2e_end = time.time()
|
||||
print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
||||
logger.info(
|
||||
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||
)
|
||||
continue
|
||||
|
||||
# Standard embedding request (passage ID lookup)
|
||||
@@ -183,14 +184,14 @@ def create_hnsw_embedding_server(
|
||||
or len(request_payload) != 1
|
||||
or not isinstance(request_payload[0], list)
|
||||
):
|
||||
print(
|
||||
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||
logger.error(
|
||||
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||
)
|
||||
socket.send(msgpack.packb([[], []]))
|
||||
continue
|
||||
|
||||
node_ids = request_payload[0]
|
||||
print(f"Request for {len(node_ids)} node embeddings")
|
||||
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||
|
||||
# Look up texts by node IDs
|
||||
texts = []
|
||||
@@ -206,19 +207,19 @@ def create_hnsw_embedding_server(
|
||||
except KeyError:
|
||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||
except Exception as e:
|
||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
||||
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||
raise
|
||||
|
||||
# Process embeddings
|
||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||
print(
|
||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
logger.info(
|
||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||
)
|
||||
|
||||
# Serialization and response
|
||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||
print(
|
||||
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
|
||||
logger.error(
|
||||
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||
)
|
||||
assert False
|
||||
|
||||
@@ -239,7 +240,7 @@ def create_hnsw_embedding_server(
|
||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Error in ZMQ server loop: {e}")
|
||||
logger.error(f"Error in ZMQ server loop: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
@@ -247,14 +248,14 @@ def create_hnsw_embedding_server(
|
||||
|
||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||
zmq_thread.start()
|
||||
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||
|
||||
# Keep the main thread alive
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("HNSW Server shutting down...")
|
||||
logger.info("HNSW Server shutting down...")
|
||||
return
|
||||
|
||||
|
||||
|
||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
Reference in New Issue
Block a user