refactor: logs

This commit is contained in:
Andy Lee
2025-07-21 22:45:24 -07:00
parent f7af6805fa
commit 573313f0b6
6 changed files with 87 additions and 103 deletions

View File

@@ -11,13 +11,10 @@ import numpy as np
import msgpack
import json
from pathlib import Path
from typing import Dict, Any, Optional, Union
from typing import Optional
import sys
import logging
RED = "\033[91m"
RESET = "\033[0m"
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
@@ -38,8 +35,8 @@ def create_hnsw_embedding_server(
Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module.
"""
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Using embedding mode: {embedding_mode}")
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent
@@ -50,9 +47,9 @@ def create_hnsw_embedding_server(
from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager
print("Successfully imported unified embedding computation module")
logger.info("Successfully imported unified embedding computation module")
except ImportError as e:
print(f"ERROR: Failed to import embedding computation module: {e}")
logger.error(f"Failed to import embedding computation module: {e}")
return
finally:
sys.path.pop(0)
@@ -65,7 +62,7 @@ def create_hnsw_embedding_server(
return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}")
logger.error(f"Port {zmq_port} is already in use")
return
# Only support metadata file, fail fast for everything else
@@ -77,7 +74,7 @@ def create_hnsw_embedding_server(
meta = json.load(f)
passages = PassageManager(meta["passage_sources"])
print(
logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
)
@@ -86,7 +83,7 @@ def create_hnsw_embedding_server(
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}")
print(f"HNSW ZMQ server listening on port {zmq_port}")
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000)
@@ -94,7 +91,7 @@ def create_hnsw_embedding_server(
while True:
try:
message_bytes = socket.recv()
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes)
@@ -131,8 +128,8 @@ def create_hnsw_embedding_server(
query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}")
logger.debug(f" Node IDs: {node_ids}")
logger.debug(f" Query vector dim: {len(query_vector)}")
# Get embeddings for node IDs
texts = []
@@ -142,20 +139,20 @@ def create_hnsw_embedding_server(
txt = passage_data["text"]
texts.append(txt)
except KeyError:
print(f"ERROR: Passage ID {nid} not found")
logger.error(f"Passage ID {nid} not found")
raise RuntimeError(
f"FATAL: Passage with ID {nid} not found"
)
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Calculate distances
@@ -170,11 +167,15 @@ def create_hnsw_embedding_server(
response_bytes = msgpack.packb(
[response_payload], use_single_float=True
)
print(f"Sending distance response with {len(distances)} distances")
logger.debug(
f"Sending distance response with {len(distances)} distances"
)
socket.send(response_bytes)
e2e_end = time.time()
print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue
# Standard embedding request (passage ID lookup)
@@ -183,14 +184,14 @@ def create_hnsw_embedding_server(
or len(request_payload) != 1
or not isinstance(request_payload[0], list)
):
print(
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
logger.error(
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
)
socket.send(msgpack.packb([[], []]))
continue
node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings")
logger.debug(f"Request for {len(node_ids)} node embeddings")
# Look up texts by node IDs
texts = []
@@ -206,19 +207,19 @@ def create_hnsw_embedding_server(
except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
logger.error(f"Exception looking up passage ID {nid}: {e}")
raise
# Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
print(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
# Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
print(
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
logger.error(
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
)
assert False
@@ -239,7 +240,7 @@ def create_hnsw_embedding_server(
logger.debug("ZMQ socket timeout, continuing to listen")
continue
except Exception as e:
print(f"Error in ZMQ server loop: {e}")
logger.error(f"Error in ZMQ server loop: {e}")
import traceback
traceback.print_exc()
@@ -247,14 +248,14 @@ def create_hnsw_embedding_server(
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
print("HNSW Server shutting down...")
logger.info("HNSW Server shutting down...")
return

View File

@@ -113,7 +113,7 @@ class PassageManager:
for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"]
index_file = source["index_path"]
index_file = source["index_path"] # .idx file
if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f:

View File

@@ -6,7 +6,7 @@ Preserves all optimization parameters to ensure performance
import numpy as np
import torch
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any
import logging
logger = logging.getLogger(__name__)
@@ -16,7 +16,10 @@ _model_cache: Dict[str, Any] = {}
def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -30,7 +33,9 @@ def compute_embeddings(
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(texts, model_name, is_build=is_build)
return compute_embeddings_sentence_transformers(
texts, model_name, is_build=is_build
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
elif mode == "mlx":
@@ -65,7 +70,7 @@ def compute_embeddings_sentence_transformers(
# Create cache key
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
# Check if model is already cached
if cache_key in _model_cache:
print(f"INFO: Using cached model: {model_name}")

View File

@@ -1,14 +1,22 @@
import threading
import time
import atexit
import socket
import subprocess
import sys
import os
import logging
from pathlib import Path
from typing import Optional
import select
import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _check_port(port: int) -> bool:
"""Check if a port is in use"""
@@ -36,11 +44,11 @@ def _check_process_matches_config(
cmdline, port, expected_model, expected_passages_file
)
print(f"DEBUG: No process found listening on port {port}")
logger.debug(f"No process found listening on port {port}")
return False
except Exception as e:
print(f"WARNING: Could not check process on port {port}: {e}")
logger.warning(f"Could not check process on port {port}: {e}")
return False
@@ -61,7 +69,7 @@ def _check_cmdline_matches_config(
) -> bool:
"""Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline)
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server
is_embedding_server = any(
@@ -74,7 +82,7 @@ def _check_cmdline_matches_config(
)
if not is_embedding_server:
print(f"DEBUG: Process on port {port} is not our embedding server")
logger.debug(f"Process on port {port} is not our embedding server")
return False
# Check model name
@@ -84,8 +92,8 @@ def _check_cmdline_matches_config(
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches
print(
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
logger.debug(
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
)
return result
@@ -132,10 +140,10 @@ def _find_compatible_port_or_next_available(
# Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file):
print(f"Found compatible server on port {port}")
logger.info(f"Found compatible server on port {port}")
return port, True
else:
print(f"⚠️ Port {port} has incompatible server, trying next port...")
logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
@@ -194,17 +202,17 @@ class EmbeddingServerManager:
port, model_name, passages_file
)
except RuntimeError as e:
print(f"{e}")
logger.error(str(e))
return False, port
if is_compatible:
print(f"Using existing compatible server on port {actual_port}")
logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port
self.server_process = None # We don't own this process
return True, actual_port
if actual_port != port:
print(f"⚠️ Using port {actual_port} instead of {port}")
logger.info(f"Using port {actual_port} instead of {port}")
# Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
@@ -221,19 +229,21 @@ class EmbeddingServerManager:
return False
if _check_process_matches_config(self.server_port, model_name, passages_file):
print(
f"Existing server process (PID {self.server_process.pid}) is compatible"
logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible"
)
return True
print("⚠️ Existing server process is incompatible. Should start a new server.")
logger.info(
"Existing server process is incompatible. Should start a new server."
)
return False
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
print(f"INFO: Starting embedding server on port {port}...")
logger.info(f"Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
@@ -241,7 +251,7 @@ class EmbeddingServerManager:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
except Exception as e:
print(f"❌ ERROR: Failed to start embedding server: {e}")
logger.error(f"Failed to start embedding server: {e}")
return False, port
def _build_server_command(
@@ -268,20 +278,18 @@ class EmbeddingServerManager:
def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Command: {' '.join(command)}")
logger.info(f"Command: {' '.join(command)}")
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
bufsize=1,
universal_newlines=True,
stdout=None, # Direct to console
stderr=None, # Direct to console
)
self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}")
logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process
if not self._atexit_registered:
@@ -294,49 +302,19 @@ class EmbeddingServerManager:
max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)):
if _check_port(port):
print("Embedding server is ready!")
threading.Thread(target=self._log_monitor, daemon=True).start()
logger.info("Embedding server is ready!")
return True, port
if self.server_process.poll() is not None:
print("❌ ERROR: Server terminated during startup.")
self._print_recent_output()
if self.server_process and self.server_process.poll() is not None:
logger.error("Server terminated during startup.")
return False, port
time.sleep(wait_interval)
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
logger.error(f"Server failed to start within {max_wait} seconds.")
self.stop_server()
return False, port
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self):
"""Stops the embedding server process if it's running."""
if not self.server_process:
@@ -347,17 +325,17 @@ class EmbeddingServerManager:
self.server_process = None
return
print(
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
logger.info(
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
)
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
print(f"INFO: Server process {self.server_process.pid} terminated.")
logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired:
print(
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
logger.warning(
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
)
self.server_process.kill()