refactor: logs

This commit is contained in:
Andy Lee
2025-07-21 22:45:24 -07:00
parent f7af6805fa
commit 573313f0b6
6 changed files with 87 additions and 103 deletions

View File

@@ -11,13 +11,10 @@ import numpy as np
import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
from typing import Optional
import sys
import logging
RED = "\033[91m"
RESET = "\033[0m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
@@ -38,8 +35,8 @@ def create_hnsw_embedding_server(
Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module.
"""
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Using embedding mode: {embedding_mode}")
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
@@ -50,9 +47,9 @@ def create_hnsw_embedding_server(
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
print("Successfully imported unified embedding computation module")
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
print(f"ERROR: Failed to import embedding computation module: {e}")
logger.error(f"Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
@@ -65,7 +62,7 @@ def create_hnsw_embedding_server(
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
logger.error(f"Port {zmq_port} is already in use")
return
# Only support metadata file, fail fast for everything else
@@ -77,7 +74,7 @@ def create_hnsw_embedding_server(
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
print(
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
@@ -86,7 +83,7 @@ def create_hnsw_embedding_server(
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
print(f"HNSW ZMQ server listening on port {zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
@@ -94,7 +91,7 @@ def create_hnsw_embedding_server(
while True:
try:
message_bytes = socket.recv()
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes)
@@ -131,8 +128,8 @@ def create_hnsw_embedding_server(
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
@@ -142,20 +139,20 @@ def create_hnsw_embedding_server(
txt = passage_data["text"]
texts.append(txt)
except KeyError:
print(f"ERROR: Passage ID {nid} not found")
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Calculate distances
@@ -170,11 +167,15 @@ def create_hnsw_embedding_server(
response_bytes = msgpack.packb(
[response_payload], use_single_float=True
)
print(f"Sending distance response with {len(distances)} distances")
logger.debug(
f"Sending distance response with {len(distances)} distances"
)
socket.send(response_bytes)
e2e_end = time.time()
print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Standard embedding request (passage ID lookup)
@@ -183,14 +184,14 @@ def create_hnsw_embedding_server(
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
print(
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
logger.debug(f"Request for {len(node_ids)} node embeddings")
# Look up texts by node IDs
texts = []
@@ -206,19 +207,19 @@ def create_hnsw_embedding_server(
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
print(
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
assert False
@@ -239,7 +240,7 @@ def create_hnsw_embedding_server(
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"Error in ZMQ server loop: {e}")
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
@@ -247,14 +248,14 @@ def create_hnsw_embedding_server(
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("HNSW Server shutting down...")
logger.info("HNSW Server shutting down...")
return