refactor: logs

This commit is contained in:
Andy Lee
2025-07-21 22:45:24 -07:00
parent f7af6805fa
commit 573313f0b6
6 changed files with 87 additions and 103 deletions

View File

@@ -11,13 +11,10 @@ import numpy as np
import msgpack import msgpack
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, Any, Optional, Union from typing import Optional
import sys import sys
import logging import logging
RED = "\033[91m"
RESET = "\033[0m"
# Set up logging based on environment variable # Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig( logging.basicConfig(
@@ -38,8 +35,8 @@ def create_hnsw_embedding_server(
Create and start a ZMQ-based embedding server for HNSW backend. Create and start a ZMQ-based embedding server for HNSW backend.
Simplified version using unified embedding computation module. Simplified version using unified embedding computation module.
""" """
print(f"Starting HNSW server on port {zmq_port} with model {model_name}") logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
print(f"Using embedding mode: {embedding_mode}") logger.info(f"Using embedding mode: {embedding_mode}")
# Add leann-core to path for unified embedding computation # Add leann-core to path for unified embedding computation
current_dir = Path(__file__).parent current_dir = Path(__file__).parent
@@ -50,9 +47,9 @@ def create_hnsw_embedding_server(
from leann.embedding_compute import compute_embeddings from leann.embedding_compute import compute_embeddings
from leann.api import PassageManager from leann.api import PassageManager
print("Successfully imported unified embedding computation module") logger.info("Successfully imported unified embedding computation module")
except ImportError as e: except ImportError as e:
print(f"ERROR: Failed to import embedding computation module: {e}") logger.error(f"Failed to import embedding computation module: {e}")
return return
finally: finally:
sys.path.pop(0) sys.path.pop(0)
@@ -65,7 +62,7 @@ def create_hnsw_embedding_server(
return s.connect_ex(("localhost", port)) == 0 return s.connect_ex(("localhost", port)) == 0
if check_port(zmq_port): if check_port(zmq_port):
print(f"{RED}Port {zmq_port} is already in use{RESET}") logger.error(f"Port {zmq_port} is already in use")
return return
# Only support metadata file, fail fast for everything else # Only support metadata file, fail fast for everything else
@@ -77,7 +74,7 @@ def create_hnsw_embedding_server(
meta = json.load(f) meta = json.load(f)
passages = PassageManager(meta["passage_sources"]) passages = PassageManager(meta["passage_sources"])
print( logger.info(
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
) )
@@ -86,7 +83,7 @@ def create_hnsw_embedding_server(
context = zmq.Context() context = zmq.Context()
socket = context.socket(zmq.REP) socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{zmq_port}") socket.bind(f"tcp://*:{zmq_port}")
print(f"HNSW ZMQ server listening on port {zmq_port}") logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
socket.setsockopt(zmq.RCVTIMEO, 300000) socket.setsockopt(zmq.RCVTIMEO, 300000)
socket.setsockopt(zmq.SNDTIMEO, 300000) socket.setsockopt(zmq.SNDTIMEO, 300000)
@@ -94,7 +91,7 @@ def create_hnsw_embedding_server(
while True: while True:
try: try:
message_bytes = socket.recv() message_bytes = socket.recv()
print(f"Received ZMQ request of size {len(message_bytes)} bytes") logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
e2e_start = time.time() e2e_start = time.time()
request_payload = msgpack.unpackb(message_bytes) request_payload = msgpack.unpackb(message_bytes)
@@ -131,8 +128,8 @@ def create_hnsw_embedding_server(
query_vector = np.array(request_payload[1], dtype=np.float32) query_vector = np.array(request_payload[1], dtype=np.float32)
logger.debug("Distance calculation request received") logger.debug("Distance calculation request received")
print(f" Node IDs: {node_ids}") logger.debug(f" Node IDs: {node_ids}")
print(f" Query vector dim: {len(query_vector)}") logger.debug(f" Query vector dim: {len(query_vector)}")
# Get embeddings for node IDs # Get embeddings for node IDs
texts = [] texts = []
@@ -142,20 +139,20 @@ def create_hnsw_embedding_server(
txt = passage_data["text"] txt = passage_data["text"]
texts.append(txt) texts.append(txt)
except KeyError: except KeyError:
print(f"ERROR: Passage ID {nid} not found") logger.error(f"Passage ID {nid} not found")
raise RuntimeError( raise RuntimeError(
f"FATAL: Passage with ID {nid} not found" f"FATAL: Passage with ID {nid} not found"
) )
except Exception as e: except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
raise raise
# Process embeddings # Process embeddings
embeddings = compute_embeddings( embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode texts, model_name, mode=embedding_mode
) )
print( logger.info(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
# Calculate distances # Calculate distances
@@ -170,11 +167,15 @@ def create_hnsw_embedding_server(
response_bytes = msgpack.packb( response_bytes = msgpack.packb(
[response_payload], use_single_float=True [response_payload], use_single_float=True
) )
print(f"Sending distance response with {len(distances)} distances") logger.debug(
f"Sending distance response with {len(distances)} distances"
)
socket.send(response_bytes) socket.send(response_bytes)
e2e_end = time.time() e2e_end = time.time()
print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") logger.info(
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
)
continue continue
# Standard embedding request (passage ID lookup) # Standard embedding request (passage ID lookup)
@@ -183,14 +184,14 @@ def create_hnsw_embedding_server(
or len(request_payload) != 1 or len(request_payload) != 1
or not isinstance(request_payload[0], list) or not isinstance(request_payload[0], list)
): ):
print( logger.error(
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}" f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
) )
socket.send(msgpack.packb([[], []])) socket.send(msgpack.packb([[], []]))
continue continue
node_ids = request_payload[0] node_ids = request_payload[0]
print(f"Request for {len(node_ids)} node embeddings") logger.debug(f"Request for {len(node_ids)} node embeddings")
# Look up texts by node IDs # Look up texts by node IDs
texts = [] texts = []
@@ -206,19 +207,19 @@ def create_hnsw_embedding_server(
except KeyError: except KeyError:
raise RuntimeError(f"FATAL: Passage with ID {nid} not found") raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
except Exception as e: except Exception as e:
print(f"ERROR: Exception looking up passage ID {nid}: {e}") logger.error(f"Exception looking up passage ID {nid}: {e}")
raise raise
# Process embeddings # Process embeddings
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
print( logger.info(
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
) )
# Serialization and response # Serialization and response
if np.isnan(embeddings).any() or np.isinf(embeddings).any(): if np.isnan(embeddings).any() or np.isinf(embeddings).any():
print( logger.error(
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}" f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
) )
assert False assert False
@@ -239,7 +240,7 @@ def create_hnsw_embedding_server(
logger.debug("ZMQ socket timeout, continuing to listen") logger.debug("ZMQ socket timeout, continuing to listen")
continue continue
except Exception as e: except Exception as e:
print(f"Error in ZMQ server loop: {e}") logger.error(f"Error in ZMQ server loop: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
@@ -247,14 +248,14 @@ def create_hnsw_embedding_server(
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
zmq_thread.start() zmq_thread.start()
print(f"Started HNSW ZMQ server thread on port {zmq_port}") logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
# Keep the main thread alive # Keep the main thread alive
try: try:
while True: while True:
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("HNSW Server shutting down...") logger.info("HNSW Server shutting down...")
return return

View File

@@ -113,7 +113,7 @@ class PassageManager:
for source in passage_sources: for source in passage_sources:
assert source["type"] == "jsonl", "only jsonl is supported" assert source["type"] == "jsonl", "only jsonl is supported"
passage_file = source["path"] passage_file = source["path"]
index_file = source["index_path"] index_file = source["index_path"] # .idx file
if not Path(index_file).exists(): if not Path(index_file).exists():
raise FileNotFoundError(f"Passage index file not found: {index_file}") raise FileNotFoundError(f"Passage index file not found: {index_file}")
with open(index_file, "rb") as f: with open(index_file, "rb") as f:

View File

@@ -6,7 +6,7 @@ Preserves all optimization parameters to ensure performance
import numpy as np import numpy as np
import torch import torch
from typing import List, Dict, Any, Optional from typing import List, Dict, Any
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,7 +16,10 @@ _model_cache: Dict[str, Any] = {}
def compute_embeddings( def compute_embeddings(
texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False texts: List[str],
model_name: str,
mode: str = "sentence-transformers",
is_build: bool = False,
) -> np.ndarray: ) -> np.ndarray:
""" """
Unified embedding computation entry point Unified embedding computation entry point
@@ -30,7 +33,9 @@ def compute_embeddings(
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
""" """
if mode == "sentence-transformers": if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(texts, model_name, is_build=is_build) return compute_embeddings_sentence_transformers(
texts, model_name, is_build=is_build
)
elif mode == "openai": elif mode == "openai":
return compute_embeddings_openai(texts, model_name) return compute_embeddings_openai(texts, model_name)
elif mode == "mlx": elif mode == "mlx":

View File

@@ -1,14 +1,22 @@
import threading
import time import time
import atexit import atexit
import socket import socket
import subprocess import subprocess
import sys import sys
import os
import logging
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import select
import psutil import psutil
# Set up logging based on environment variable
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
level=getattr(logging, LOG_LEVEL, logging.INFO),
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def _check_port(port: int) -> bool: def _check_port(port: int) -> bool:
"""Check if a port is in use""" """Check if a port is in use"""
@@ -36,11 +44,11 @@ def _check_process_matches_config(
cmdline, port, expected_model, expected_passages_file cmdline, port, expected_model, expected_passages_file
) )
print(f"DEBUG: No process found listening on port {port}") logger.debug(f"No process found listening on port {port}")
return False return False
except Exception as e: except Exception as e:
print(f"WARNING: Could not check process on port {port}: {e}") logger.warning(f"Could not check process on port {port}: {e}")
return False return False
@@ -61,7 +69,7 @@ def _check_cmdline_matches_config(
) -> bool: ) -> bool:
"""Check if command line matches our expected configuration.""" """Check if command line matches our expected configuration."""
cmdline_str = " ".join(cmdline) cmdline_str = " ".join(cmdline)
print(f"DEBUG: Found process on port {port}: {cmdline_str}") logger.debug(f"Found process on port {port}: {cmdline_str}")
# Check if it's our embedding server # Check if it's our embedding server
is_embedding_server = any( is_embedding_server = any(
@@ -74,7 +82,7 @@ def _check_cmdline_matches_config(
) )
if not is_embedding_server: if not is_embedding_server:
print(f"DEBUG: Process on port {port} is not our embedding server") logger.debug(f"Process on port {port} is not our embedding server")
return False return False
# Check model name # Check model name
@@ -84,8 +92,8 @@ def _check_cmdline_matches_config(
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file) passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
result = model_matches and passages_matches result = model_matches and passages_matches
print( logger.debug(
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}" f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
) )
return result return result
@@ -132,10 +140,10 @@ def _find_compatible_port_or_next_available(
# Port is in use, check if it's compatible # Port is in use, check if it's compatible
if _check_process_matches_config(port, model_name, passages_file): if _check_process_matches_config(port, model_name, passages_file):
print(f"Found compatible server on port {port}") logger.info(f"Found compatible server on port {port}")
return port, True return port, True
else: else:
print(f"⚠️ Port {port} has incompatible server, trying next port...") logger.info(f"Port {port} has incompatible server, trying next port...")
raise RuntimeError( raise RuntimeError(
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}" f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
@@ -194,17 +202,17 @@ class EmbeddingServerManager:
port, model_name, passages_file port, model_name, passages_file
) )
except RuntimeError as e: except RuntimeError as e:
print(f"{e}") logger.error(str(e))
return False, port return False, port
if is_compatible: if is_compatible:
print(f"Using existing compatible server on port {actual_port}") logger.info(f"Using existing compatible server on port {actual_port}")
self.server_port = actual_port self.server_port = actual_port
self.server_process = None # We don't own this process self.server_process = None # We don't own this process
return True, actual_port return True, actual_port
if actual_port != port: if actual_port != port:
print(f"⚠️ Using port {actual_port} instead of {port}") logger.info(f"Using port {actual_port} instead of {port}")
# Start new server # Start new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
@@ -221,19 +229,21 @@ class EmbeddingServerManager:
return False return False
if _check_process_matches_config(self.server_port, model_name, passages_file): if _check_process_matches_config(self.server_port, model_name, passages_file):
print( logger.info(
f"Existing server process (PID {self.server_process.pid}) is compatible" f"Existing server process (PID {self.server_process.pid}) is compatible"
) )
return True return True
print("⚠️ Existing server process is incompatible. Should start a new server.") logger.info(
"Existing server process is incompatible. Should start a new server."
)
return False return False
def _start_new_server( def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs self, port: int, model_name: str, embedding_mode: str, **kwargs
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start a new embedding server on the given port.""" """Start a new embedding server on the given port."""
print(f"INFO: Starting embedding server on port {port}...") logger.info(f"Starting embedding server on port {port}...")
command = self._build_server_command(port, model_name, embedding_mode, **kwargs) command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
@@ -241,7 +251,7 @@ class EmbeddingServerManager:
self._launch_server_process(command, port) self._launch_server_process(command, port)
return self._wait_for_server_ready(port) return self._wait_for_server_ready(port)
except Exception as e: except Exception as e:
print(f"❌ ERROR: Failed to start embedding server: {e}") logger.error(f"Failed to start embedding server: {e}")
return False, port return False, port
def _build_server_command( def _build_server_command(
@@ -268,20 +278,18 @@ class EmbeddingServerManager:
def _launch_server_process(self, command: list, port: int) -> None: def _launch_server_process(self, command: list, port: int) -> None:
"""Launch the server process.""" """Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
print(f"INFO: Command: {' '.join(command)}") logger.info(f"Command: {' '.join(command)}")
# Let server output go directly to console
# The server will respect LEANN_LOG_LEVEL environment variable
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
cwd=project_root, cwd=project_root,
stdout=subprocess.PIPE, stdout=None, # Direct to console
stderr=subprocess.STDOUT, stderr=None, # Direct to console
text=True,
encoding="utf-8",
bufsize=1,
universal_newlines=True,
) )
self.server_port = port self.server_port = port
print(f"INFO: Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process # Register atexit callback only when we actually start a process
if not self._atexit_registered: if not self._atexit_registered:
@@ -294,49 +302,19 @@ class EmbeddingServerManager:
max_wait, wait_interval = 120, 0.5 max_wait, wait_interval = 120, 0.5
for _ in range(int(max_wait / wait_interval)): for _ in range(int(max_wait / wait_interval)):
if _check_port(port): if _check_port(port):
print("Embedding server is ready!") logger.info("Embedding server is ready!")
threading.Thread(target=self._log_monitor, daemon=True).start()
return True, port return True, port
if self.server_process.poll() is not None: if self.server_process and self.server_process.poll() is not None:
print("❌ ERROR: Server terminated during startup.") logger.error("Server terminated during startup.")
self._print_recent_output()
return False, port return False, port
time.sleep(wait_interval) time.sleep(wait_interval)
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.") logger.error(f"Server failed to start within {max_wait} seconds.")
self.stop_server() self.stop_server()
return False, port return False, port
def _print_recent_output(self):
"""Print any recent output from the server process."""
if not self.server_process or not self.server_process.stdout:
return
try:
if select.select([self.server_process.stdout], [], [], 0)[0]:
output = self.server_process.stdout.read()
if output:
print(f"[{self.backend_module_name} OUTPUT]: {output}")
except Exception as e:
print(f"Error reading server output: {e}")
def _log_monitor(self):
"""Monitors and prints the server's stdout and stderr."""
if not self.server_process:
return
try:
if self.server_process.stdout:
while True:
line = self.server_process.stdout.readline()
if not line:
break
print(
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
)
except Exception as e:
print(f"Log monitor error: {e}")
def stop_server(self): def stop_server(self):
"""Stops the embedding server process if it's running.""" """Stops the embedding server process if it's running."""
if not self.server_process: if not self.server_process:
@@ -347,17 +325,17 @@ class EmbeddingServerManager:
self.server_process = None self.server_process = None
return return
print( logger.info(
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
) )
self.server_process.terminate() self.server_process.terminate()
try: try:
self.server_process.wait(timeout=5) self.server_process.wait(timeout=5)
print(f"INFO: Server process {self.server_process.pid} terminated.") logger.info(f"Server process {self.server_process.pid} terminated.")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
print( logger.warning(
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it." f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
) )
self.server_process.kill() self.server_process.kill()