From 573313f0b6b8cf423670868a032897942004adb1 Mon Sep 17 00:00:00 2001 From: Andy Lee Date: Mon, 21 Jul 2025 22:45:24 -0700 Subject: [PATCH] refactor: logs --- .../hnsw_embedding_server.py | 63 +++++----- packages/leann-backend-hnsw/third_party/faiss | 2 +- .../leann-backend-hnsw/third_party/msgpack-c | 2 +- packages/leann-core/src/leann/api.py | 2 +- .../leann-core/src/leann/embedding_compute.py | 13 ++- .../src/leann/embedding_server_manager.py | 108 +++++++----------- 6 files changed, 87 insertions(+), 103 deletions(-) diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py index eaad3da..94b6529 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_embedding_server.py @@ -11,13 +11,10 @@ import numpy as np import msgpack import json from pathlib import Path -from typing import Dict, Any, Optional, Union +from typing import Optional import sys import logging -RED = "\033[91m" -RESET = "\033[0m" - # Set up logging based on environment variable LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() logging.basicConfig( @@ -38,8 +35,8 @@ def create_hnsw_embedding_server( Create and start a ZMQ-based embedding server for HNSW backend. Simplified version using unified embedding computation module. """ - print(f"Starting HNSW server on port {zmq_port} with model {model_name}") - print(f"Using embedding mode: {embedding_mode}") + logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}") + logger.info(f"Using embedding mode: {embedding_mode}") # Add leann-core to path for unified embedding computation current_dir = Path(__file__).parent @@ -50,9 +47,9 @@ def create_hnsw_embedding_server( from leann.embedding_compute import compute_embeddings from leann.api import PassageManager - print("Successfully imported unified embedding computation module") + logger.info("Successfully imported unified embedding computation module") except ImportError as e: - print(f"ERROR: Failed to import embedding computation module: {e}") + logger.error(f"Failed to import embedding computation module: {e}") return finally: sys.path.pop(0) @@ -65,7 +62,7 @@ def create_hnsw_embedding_server( return s.connect_ex(("localhost", port)) == 0 if check_port(zmq_port): - print(f"{RED}Port {zmq_port} is already in use{RESET}") + logger.error(f"Port {zmq_port} is already in use") return # Only support metadata file, fail fast for everything else @@ -77,7 +74,7 @@ def create_hnsw_embedding_server( meta = json.load(f) passages = PassageManager(meta["passage_sources"]) - print( + logger.info( f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata" ) @@ -86,7 +83,7 @@ def create_hnsw_embedding_server( context = zmq.Context() socket = context.socket(zmq.REP) socket.bind(f"tcp://*:{zmq_port}") - print(f"HNSW ZMQ server listening on port {zmq_port}") + logger.info(f"HNSW ZMQ server listening on port {zmq_port}") socket.setsockopt(zmq.RCVTIMEO, 300000) socket.setsockopt(zmq.SNDTIMEO, 300000) @@ -94,7 +91,7 @@ def create_hnsw_embedding_server( while True: try: message_bytes = socket.recv() - print(f"Received ZMQ request of size {len(message_bytes)} bytes") + logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes") e2e_start = time.time() request_payload = msgpack.unpackb(message_bytes) @@ -131,8 +128,8 @@ def create_hnsw_embedding_server( query_vector = np.array(request_payload[1], dtype=np.float32) logger.debug("Distance calculation request received") - print(f" Node IDs: {node_ids}") - print(f" Query vector dim: {len(query_vector)}") + logger.debug(f" Node IDs: {node_ids}") + logger.debug(f" Query vector dim: {len(query_vector)}") # Get embeddings for node IDs texts = [] @@ -142,20 +139,20 @@ def create_hnsw_embedding_server( txt = passage_data["text"] texts.append(txt) except KeyError: - print(f"ERROR: Passage ID {nid} not found") + logger.error(f"Passage ID {nid} not found") raise RuntimeError( f"FATAL: Passage with ID {nid} not found" ) except Exception as e: - print(f"ERROR: Exception looking up passage ID {nid}: {e}") + logger.error(f"Exception looking up passage ID {nid}: {e}") raise # Process embeddings embeddings = compute_embeddings( texts, model_name, mode=embedding_mode ) - print( - f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) # Calculate distances @@ -170,11 +167,15 @@ def create_hnsw_embedding_server( response_bytes = msgpack.packb( [response_payload], use_single_float=True ) - print(f"Sending distance response with {len(distances)} distances") + logger.debug( + f"Sending distance response with {len(distances)} distances" + ) socket.send(response_bytes) e2e_end = time.time() - print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s") + logger.info( + f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s" + ) continue # Standard embedding request (passage ID lookup) @@ -183,14 +184,14 @@ def create_hnsw_embedding_server( or len(request_payload) != 1 or not isinstance(request_payload[0], list) ): - print( - f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}" + logger.error( + f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}" ) socket.send(msgpack.packb([[], []])) continue node_ids = request_payload[0] - print(f"Request for {len(node_ids)} node embeddings") + logger.debug(f"Request for {len(node_ids)} node embeddings") # Look up texts by node IDs texts = [] @@ -206,19 +207,19 @@ def create_hnsw_embedding_server( except KeyError: raise RuntimeError(f"FATAL: Passage with ID {nid} not found") except Exception as e: - print(f"ERROR: Exception looking up passage ID {nid}: {e}") + logger.error(f"Exception looking up passage ID {nid}: {e}") raise # Process embeddings embeddings = compute_embeddings(texts, model_name, mode=embedding_mode) - print( - f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" + logger.info( + f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}" ) # Serialization and response if np.isnan(embeddings).any() or np.isinf(embeddings).any(): - print( - f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}" + logger.error( + f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..." ) assert False @@ -239,7 +240,7 @@ def create_hnsw_embedding_server( logger.debug("ZMQ socket timeout, continuing to listen") continue except Exception as e: - print(f"Error in ZMQ server loop: {e}") + logger.error(f"Error in ZMQ server loop: {e}") import traceback traceback.print_exc() @@ -247,14 +248,14 @@ def create_hnsw_embedding_server( zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True) zmq_thread.start() - print(f"Started HNSW ZMQ server thread on port {zmq_port}") + logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}") # Keep the main thread alive try: while True: time.sleep(1) except KeyboardInterrupt: - print("HNSW Server shutting down...") + logger.info("HNSW Server shutting down...") return diff --git a/packages/leann-backend-hnsw/third_party/faiss b/packages/leann-backend-hnsw/third_party/faiss index 2547df4..ff22e2c 160000 --- a/packages/leann-backend-hnsw/third_party/faiss +++ b/packages/leann-backend-hnsw/third_party/faiss @@ -1 +1 @@ -Subproject commit 2547df4377ae097e2eabc9b019c15135b1fea2b4 +Subproject commit ff22e2c86be1784c760265abe146b1ab0db90ebe diff --git a/packages/leann-backend-hnsw/third_party/msgpack-c b/packages/leann-backend-hnsw/third_party/msgpack-c index 9b801f0..a0b2ec0 160000 --- a/packages/leann-backend-hnsw/third_party/msgpack-c +++ b/packages/leann-backend-hnsw/third_party/msgpack-c @@ -1 +1 @@ -Subproject commit 9b801f087ab7434f2ab1ab3c0f48a966c19d3b70 +Subproject commit a0b2ec09da4bd823e40fa591221713951d4ec995 diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index ff41912..879708c 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -113,7 +113,7 @@ class PassageManager: for source in passage_sources: assert source["type"] == "jsonl", "only jsonl is supported" passage_file = source["path"] - index_file = source["index_path"] + index_file = source["index_path"] # .idx file if not Path(index_file).exists(): raise FileNotFoundError(f"Passage index file not found: {index_file}") with open(index_file, "rb") as f: diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 3b30798..38dde43 100644 --- a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -6,7 +6,7 @@ Preserves all optimization parameters to ensure performance import numpy as np import torch -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any import logging logger = logging.getLogger(__name__) @@ -16,7 +16,10 @@ _model_cache: Dict[str, Any] = {} def compute_embeddings( - texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False + texts: List[str], + model_name: str, + mode: str = "sentence-transformers", + is_build: bool = False, ) -> np.ndarray: """ Unified embedding computation entry point @@ -30,7 +33,9 @@ def compute_embeddings( Normalized embeddings array, shape: (len(texts), embedding_dim) """ if mode == "sentence-transformers": - return compute_embeddings_sentence_transformers(texts, model_name, is_build=is_build) + return compute_embeddings_sentence_transformers( + texts, model_name, is_build=is_build + ) elif mode == "openai": return compute_embeddings_openai(texts, model_name) elif mode == "mlx": @@ -65,7 +70,7 @@ def compute_embeddings_sentence_transformers( # Create cache key cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}" - + # Check if model is already cached if cache_key in _model_cache: print(f"INFO: Using cached model: {model_name}") diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 6a44160..9ef7c78 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -1,14 +1,22 @@ -import threading import time import atexit import socket import subprocess import sys +import os +import logging from pathlib import Path from typing import Optional -import select import psutil +# Set up logging based on environment variable +LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() +logging.basicConfig( + level=getattr(logging, LOG_LEVEL, logging.INFO), + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + def _check_port(port: int) -> bool: """Check if a port is in use""" @@ -36,11 +44,11 @@ def _check_process_matches_config( cmdline, port, expected_model, expected_passages_file ) - print(f"DEBUG: No process found listening on port {port}") + logger.debug(f"No process found listening on port {port}") return False except Exception as e: - print(f"WARNING: Could not check process on port {port}: {e}") + logger.warning(f"Could not check process on port {port}: {e}") return False @@ -61,7 +69,7 @@ def _check_cmdline_matches_config( ) -> bool: """Check if command line matches our expected configuration.""" cmdline_str = " ".join(cmdline) - print(f"DEBUG: Found process on port {port}: {cmdline_str}") + logger.debug(f"Found process on port {port}: {cmdline_str}") # Check if it's our embedding server is_embedding_server = any( @@ -74,7 +82,7 @@ def _check_cmdline_matches_config( ) if not is_embedding_server: - print(f"DEBUG: Process on port {port} is not our embedding server") + logger.debug(f"Process on port {port} is not our embedding server") return False # Check model name @@ -84,8 +92,8 @@ def _check_cmdline_matches_config( passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file) result = model_matches and passages_matches - print( - f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}" + logger.debug( + f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}" ) return result @@ -132,10 +140,10 @@ def _find_compatible_port_or_next_available( # Port is in use, check if it's compatible if _check_process_matches_config(port, model_name, passages_file): - print(f"✅ Found compatible server on port {port}") + logger.info(f"Found compatible server on port {port}") return port, True else: - print(f"⚠️ Port {port} has incompatible server, trying next port...") + logger.info(f"Port {port} has incompatible server, trying next port...") raise RuntimeError( f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}" @@ -194,17 +202,17 @@ class EmbeddingServerManager: port, model_name, passages_file ) except RuntimeError as e: - print(f"❌ {e}") + logger.error(str(e)) return False, port if is_compatible: - print(f"✅ Using existing compatible server on port {actual_port}") + logger.info(f"Using existing compatible server on port {actual_port}") self.server_port = actual_port self.server_process = None # We don't own this process return True, actual_port if actual_port != port: - print(f"⚠️ Using port {actual_port} instead of {port}") + logger.info(f"Using port {actual_port} instead of {port}") # Start new server return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) @@ -221,19 +229,21 @@ class EmbeddingServerManager: return False if _check_process_matches_config(self.server_port, model_name, passages_file): - print( - f"✅ Existing server process (PID {self.server_process.pid}) is compatible" + logger.info( + f"Existing server process (PID {self.server_process.pid}) is compatible" ) return True - print("⚠️ Existing server process is incompatible. Should start a new server.") + logger.info( + "Existing server process is incompatible. Should start a new server." + ) return False def _start_new_server( self, port: int, model_name: str, embedding_mode: str, **kwargs ) -> tuple[bool, int]: """Start a new embedding server on the given port.""" - print(f"INFO: Starting embedding server on port {port}...") + logger.info(f"Starting embedding server on port {port}...") command = self._build_server_command(port, model_name, embedding_mode, **kwargs) @@ -241,7 +251,7 @@ class EmbeddingServerManager: self._launch_server_process(command, port) return self._wait_for_server_ready(port) except Exception as e: - print(f"❌ ERROR: Failed to start embedding server: {e}") + logger.error(f"Failed to start embedding server: {e}") return False, port def _build_server_command( @@ -268,20 +278,18 @@ class EmbeddingServerManager: def _launch_server_process(self, command: list, port: int) -> None: """Launch the server process.""" project_root = Path(__file__).parent.parent.parent.parent.parent - print(f"INFO: Command: {' '.join(command)}") + logger.info(f"Command: {' '.join(command)}") + # Let server output go directly to console + # The server will respect LEANN_LOG_LEVEL environment variable self.server_process = subprocess.Popen( command, cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - encoding="utf-8", - bufsize=1, - universal_newlines=True, + stdout=None, # Direct to console + stderr=None, # Direct to console ) self.server_port = port - print(f"INFO: Server process started with PID: {self.server_process.pid}") + logger.info(f"Server process started with PID: {self.server_process.pid}") # Register atexit callback only when we actually start a process if not self._atexit_registered: @@ -294,49 +302,19 @@ class EmbeddingServerManager: max_wait, wait_interval = 120, 0.5 for _ in range(int(max_wait / wait_interval)): if _check_port(port): - print("✅ Embedding server is ready!") - threading.Thread(target=self._log_monitor, daemon=True).start() + logger.info("Embedding server is ready!") return True, port - if self.server_process.poll() is not None: - print("❌ ERROR: Server terminated during startup.") - self._print_recent_output() + if self.server_process and self.server_process.poll() is not None: + logger.error("Server terminated during startup.") return False, port time.sleep(wait_interval) - print(f"❌ ERROR: Server failed to start within {max_wait} seconds.") + logger.error(f"Server failed to start within {max_wait} seconds.") self.stop_server() return False, port - def _print_recent_output(self): - """Print any recent output from the server process.""" - if not self.server_process or not self.server_process.stdout: - return - try: - if select.select([self.server_process.stdout], [], [], 0)[0]: - output = self.server_process.stdout.read() - if output: - print(f"[{self.backend_module_name} OUTPUT]: {output}") - except Exception as e: - print(f"Error reading server output: {e}") - - def _log_monitor(self): - """Monitors and prints the server's stdout and stderr.""" - if not self.server_process: - return - try: - if self.server_process.stdout: - while True: - line = self.server_process.stdout.readline() - if not line: - break - print( - f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True - ) - except Exception as e: - print(f"Log monitor error: {e}") - def stop_server(self): """Stops the embedding server process if it's running.""" if not self.server_process: @@ -347,17 +325,17 @@ class EmbeddingServerManager: self.server_process = None return - print( - f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." + logger.info( + f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..." ) self.server_process.terminate() try: self.server_process.wait(timeout=5) - print(f"INFO: Server process {self.server_process.pid} terminated.") + logger.info(f"Server process {self.server_process.pid} terminated.") except subprocess.TimeoutExpired: - print( - f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it." + logger.warning( + f"Server process {self.server_process.pid} did not terminate gracefully, killing it." ) self.server_process.kill()