refactor: logs
This commit is contained in:
@@ -11,13 +11,10 @@ import numpy as np
|
|||||||
import msgpack
|
import msgpack
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional, Union
|
from typing import Optional
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
RED = "\033[91m"
|
|
||||||
RESET = "\033[0m"
|
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -38,8 +35,8 @@ def create_hnsw_embedding_server(
|
|||||||
Create and start a ZMQ-based embedding server for HNSW backend.
|
Create and start a ZMQ-based embedding server for HNSW backend.
|
||||||
Simplified version using unified embedding computation module.
|
Simplified version using unified embedding computation module.
|
||||||
"""
|
"""
|
||||||
print(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
logger.info(f"Starting HNSW server on port {zmq_port} with model {model_name}")
|
||||||
print(f"Using embedding mode: {embedding_mode}")
|
logger.info(f"Using embedding mode: {embedding_mode}")
|
||||||
|
|
||||||
# Add leann-core to path for unified embedding computation
|
# Add leann-core to path for unified embedding computation
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
@@ -50,9 +47,9 @@ def create_hnsw_embedding_server(
|
|||||||
from leann.embedding_compute import compute_embeddings
|
from leann.embedding_compute import compute_embeddings
|
||||||
from leann.api import PassageManager
|
from leann.api import PassageManager
|
||||||
|
|
||||||
print("Successfully imported unified embedding computation module")
|
logger.info("Successfully imported unified embedding computation module")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"ERROR: Failed to import embedding computation module: {e}")
|
logger.error(f"Failed to import embedding computation module: {e}")
|
||||||
return
|
return
|
||||||
finally:
|
finally:
|
||||||
sys.path.pop(0)
|
sys.path.pop(0)
|
||||||
@@ -65,7 +62,7 @@ def create_hnsw_embedding_server(
|
|||||||
return s.connect_ex(("localhost", port)) == 0
|
return s.connect_ex(("localhost", port)) == 0
|
||||||
|
|
||||||
if check_port(zmq_port):
|
if check_port(zmq_port):
|
||||||
print(f"{RED}Port {zmq_port} is already in use{RESET}")
|
logger.error(f"Port {zmq_port} is already in use")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Only support metadata file, fail fast for everything else
|
# Only support metadata file, fail fast for everything else
|
||||||
@@ -77,7 +74,7 @@ def create_hnsw_embedding_server(
|
|||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
|
|
||||||
passages = PassageManager(meta["passage_sources"])
|
passages = PassageManager(meta["passage_sources"])
|
||||||
print(
|
logger.info(
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -86,7 +83,7 @@ def create_hnsw_embedding_server(
|
|||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
socket = context.socket(zmq.REP)
|
socket = context.socket(zmq.REP)
|
||||||
socket.bind(f"tcp://*:{zmq_port}")
|
socket.bind(f"tcp://*:{zmq_port}")
|
||||||
print(f"HNSW ZMQ server listening on port {zmq_port}")
|
logger.info(f"HNSW ZMQ server listening on port {zmq_port}")
|
||||||
|
|
||||||
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
socket.setsockopt(zmq.RCVTIMEO, 300000)
|
||||||
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
socket.setsockopt(zmq.SNDTIMEO, 300000)
|
||||||
@@ -94,7 +91,7 @@ def create_hnsw_embedding_server(
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
message_bytes = socket.recv()
|
message_bytes = socket.recv()
|
||||||
print(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
logger.debug(f"Received ZMQ request of size {len(message_bytes)} bytes")
|
||||||
|
|
||||||
e2e_start = time.time()
|
e2e_start = time.time()
|
||||||
request_payload = msgpack.unpackb(message_bytes)
|
request_payload = msgpack.unpackb(message_bytes)
|
||||||
@@ -131,8 +128,8 @@ def create_hnsw_embedding_server(
|
|||||||
query_vector = np.array(request_payload[1], dtype=np.float32)
|
query_vector = np.array(request_payload[1], dtype=np.float32)
|
||||||
|
|
||||||
logger.debug("Distance calculation request received")
|
logger.debug("Distance calculation request received")
|
||||||
print(f" Node IDs: {node_ids}")
|
logger.debug(f" Node IDs: {node_ids}")
|
||||||
print(f" Query vector dim: {len(query_vector)}")
|
logger.debug(f" Query vector dim: {len(query_vector)}")
|
||||||
|
|
||||||
# Get embeddings for node IDs
|
# Get embeddings for node IDs
|
||||||
texts = []
|
texts = []
|
||||||
@@ -142,20 +139,20 @@ def create_hnsw_embedding_server(
|
|||||||
txt = passage_data["text"]
|
txt = passage_data["text"]
|
||||||
texts.append(txt)
|
texts.append(txt)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(f"ERROR: Passage ID {nid} not found")
|
logger.error(f"Passage ID {nid} not found")
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"FATAL: Passage with ID {nid} not found"
|
f"FATAL: Passage with ID {nid} not found"
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Process embeddings
|
# Process embeddings
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
texts, model_name, mode=embedding_mode
|
texts, model_name, mode=embedding_mode
|
||||||
)
|
)
|
||||||
print(
|
logger.info(
|
||||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate distances
|
# Calculate distances
|
||||||
@@ -170,11 +167,15 @@ def create_hnsw_embedding_server(
|
|||||||
response_bytes = msgpack.packb(
|
response_bytes = msgpack.packb(
|
||||||
[response_payload], use_single_float=True
|
[response_payload], use_single_float=True
|
||||||
)
|
)
|
||||||
print(f"Sending distance response with {len(distances)} distances")
|
logger.debug(
|
||||||
|
f"Sending distance response with {len(distances)} distances"
|
||||||
|
)
|
||||||
|
|
||||||
socket.send(response_bytes)
|
socket.send(response_bytes)
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
print(f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(
|
||||||
|
f"⏱️ Distance calculation E2E time: {e2e_end - e2e_start:.6f}s"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Standard embedding request (passage ID lookup)
|
# Standard embedding request (passage ID lookup)
|
||||||
@@ -183,14 +184,14 @@ def create_hnsw_embedding_server(
|
|||||||
or len(request_payload) != 1
|
or len(request_payload) != 1
|
||||||
or not isinstance(request_payload[0], list)
|
or not isinstance(request_payload[0], list)
|
||||||
):
|
):
|
||||||
print(
|
logger.error(
|
||||||
f"Error: Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
f"Invalid MessagePack request format. Expected [[ids...]] or [texts...], got: {type(request_payload)}"
|
||||||
)
|
)
|
||||||
socket.send(msgpack.packb([[], []]))
|
socket.send(msgpack.packb([[], []]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
node_ids = request_payload[0]
|
node_ids = request_payload[0]
|
||||||
print(f"Request for {len(node_ids)} node embeddings")
|
logger.debug(f"Request for {len(node_ids)} node embeddings")
|
||||||
|
|
||||||
# Look up texts by node IDs
|
# Look up texts by node IDs
|
||||||
texts = []
|
texts = []
|
||||||
@@ -206,19 +207,19 @@ def create_hnsw_embedding_server(
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
raise RuntimeError(f"FATAL: Passage with ID {nid} not found")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR: Exception looking up passage ID {nid}: {e}")
|
logger.error(f"Exception looking up passage ID {nid}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Process embeddings
|
# Process embeddings
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
||||||
print(
|
logger.info(
|
||||||
f"INFO: Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Serialization and response
|
# Serialization and response
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
print(
|
logger.error(
|
||||||
f"{RED}!!! ERROR: NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}...{RESET}"
|
f"NaN or Inf detected in embeddings! Requested IDs: {node_ids[:5]}..."
|
||||||
)
|
)
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
@@ -239,7 +240,7 @@ def create_hnsw_embedding_server(
|
|||||||
logger.debug("ZMQ socket timeout, continuing to listen")
|
logger.debug("ZMQ socket timeout, continuing to listen")
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error in ZMQ server loop: {e}")
|
logger.error(f"Error in ZMQ server loop: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@@ -247,14 +248,14 @@ def create_hnsw_embedding_server(
|
|||||||
|
|
||||||
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
zmq_thread = threading.Thread(target=zmq_server_thread, daemon=True)
|
||||||
zmq_thread.start()
|
zmq_thread.start()
|
||||||
print(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
logger.info(f"Started HNSW ZMQ server thread on port {zmq_port}")
|
||||||
|
|
||||||
# Keep the main thread alive
|
# Keep the main thread alive
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("HNSW Server shutting down...")
|
logger.info("HNSW Server shutting down...")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 2547df4377...ff22e2c86b
Submodule packages/leann-backend-hnsw/third_party/msgpack-c updated: 9b801f087a...a0b2ec09da
@@ -113,7 +113,7 @@ class PassageManager:
|
|||||||
for source in passage_sources:
|
for source in passage_sources:
|
||||||
assert source["type"] == "jsonl", "only jsonl is supported"
|
assert source["type"] == "jsonl", "only jsonl is supported"
|
||||||
passage_file = source["path"]
|
passage_file = source["path"]
|
||||||
index_file = source["index_path"]
|
index_file = source["index_path"] # .idx file
|
||||||
if not Path(index_file).exists():
|
if not Path(index_file).exists():
|
||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -16,7 +16,10 @@ _model_cache: Dict[str, Any] = {}
|
|||||||
|
|
||||||
|
|
||||||
def compute_embeddings(
|
def compute_embeddings(
|
||||||
texts: List[str], model_name: str, mode: str = "sentence-transformers",is_build: bool = False
|
texts: List[str],
|
||||||
|
model_name: str,
|
||||||
|
mode: str = "sentence-transformers",
|
||||||
|
is_build: bool = False,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -30,7 +33,9 @@ def compute_embeddings(
|
|||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
"""
|
"""
|
||||||
if mode == "sentence-transformers":
|
if mode == "sentence-transformers":
|
||||||
return compute_embeddings_sentence_transformers(texts, model_name, is_build=is_build)
|
return compute_embeddings_sentence_transformers(
|
||||||
|
texts, model_name, is_build=is_build
|
||||||
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
@@ -65,7 +70,7 @@ def compute_embeddings_sentence_transformers(
|
|||||||
|
|
||||||
# Create cache key
|
# Create cache key
|
||||||
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
|
cache_key = f"sentence_transformers_{model_name}_{device}_{use_fp16}"
|
||||||
|
|
||||||
# Check if model is already cached
|
# Check if model is already cached
|
||||||
if cache_key in _model_cache:
|
if cache_key in _model_cache:
|
||||||
print(f"INFO: Using cached model: {model_name}")
|
print(f"INFO: Using cached model: {model_name}")
|
||||||
|
|||||||
@@ -1,14 +1,22 @@
|
|||||||
import threading
|
|
||||||
import time
|
import time
|
||||||
import atexit
|
import atexit
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import select
|
|
||||||
import psutil
|
import psutil
|
||||||
|
|
||||||
|
# Set up logging based on environment variable
|
||||||
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, LOG_LEVEL, logging.INFO),
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _check_port(port: int) -> bool:
|
def _check_port(port: int) -> bool:
|
||||||
"""Check if a port is in use"""
|
"""Check if a port is in use"""
|
||||||
@@ -36,11 +44,11 @@ def _check_process_matches_config(
|
|||||||
cmdline, port, expected_model, expected_passages_file
|
cmdline, port, expected_model, expected_passages_file
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"DEBUG: No process found listening on port {port}")
|
logger.debug(f"No process found listening on port {port}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"WARNING: Could not check process on port {port}: {e}")
|
logger.warning(f"Could not check process on port {port}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -61,7 +69,7 @@ def _check_cmdline_matches_config(
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if command line matches our expected configuration."""
|
"""Check if command line matches our expected configuration."""
|
||||||
cmdline_str = " ".join(cmdline)
|
cmdline_str = " ".join(cmdline)
|
||||||
print(f"DEBUG: Found process on port {port}: {cmdline_str}")
|
logger.debug(f"Found process on port {port}: {cmdline_str}")
|
||||||
|
|
||||||
# Check if it's our embedding server
|
# Check if it's our embedding server
|
||||||
is_embedding_server = any(
|
is_embedding_server = any(
|
||||||
@@ -74,7 +82,7 @@ def _check_cmdline_matches_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not is_embedding_server:
|
if not is_embedding_server:
|
||||||
print(f"DEBUG: Process on port {port} is not our embedding server")
|
logger.debug(f"Process on port {port} is not our embedding server")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check model name
|
# Check model name
|
||||||
@@ -84,8 +92,8 @@ def _check_cmdline_matches_config(
|
|||||||
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file)
|
||||||
|
|
||||||
result = model_matches and passages_matches
|
result = model_matches and passages_matches
|
||||||
print(
|
logger.debug(
|
||||||
f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
f"model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}"
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -132,10 +140,10 @@ def _find_compatible_port_or_next_available(
|
|||||||
|
|
||||||
# Port is in use, check if it's compatible
|
# Port is in use, check if it's compatible
|
||||||
if _check_process_matches_config(port, model_name, passages_file):
|
if _check_process_matches_config(port, model_name, passages_file):
|
||||||
print(f"✅ Found compatible server on port {port}")
|
logger.info(f"Found compatible server on port {port}")
|
||||||
return port, True
|
return port, True
|
||||||
else:
|
else:
|
||||||
print(f"⚠️ Port {port} has incompatible server, trying next port...")
|
logger.info(f"Port {port} has incompatible server, trying next port...")
|
||||||
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}"
|
||||||
@@ -194,17 +202,17 @@ class EmbeddingServerManager:
|
|||||||
port, model_name, passages_file
|
port, model_name, passages_file
|
||||||
)
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
print(f"❌ {e}")
|
logger.error(str(e))
|
||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
if is_compatible:
|
if is_compatible:
|
||||||
print(f"✅ Using existing compatible server on port {actual_port}")
|
logger.info(f"Using existing compatible server on port {actual_port}")
|
||||||
self.server_port = actual_port
|
self.server_port = actual_port
|
||||||
self.server_process = None # We don't own this process
|
self.server_process = None # We don't own this process
|
||||||
return True, actual_port
|
return True, actual_port
|
||||||
|
|
||||||
if actual_port != port:
|
if actual_port != port:
|
||||||
print(f"⚠️ Using port {actual_port} instead of {port}")
|
logger.info(f"Using port {actual_port} instead of {port}")
|
||||||
|
|
||||||
# Start new server
|
# Start new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
||||||
@@ -221,19 +229,21 @@ class EmbeddingServerManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
if _check_process_matches_config(self.server_port, model_name, passages_file):
|
||||||
print(
|
logger.info(
|
||||||
f"✅ Existing server process (PID {self.server_process.pid}) is compatible"
|
f"Existing server process (PID {self.server_process.pid}) is compatible"
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
print("⚠️ Existing server process is incompatible. Should start a new server.")
|
logger.info(
|
||||||
|
"Existing server process is incompatible. Should start a new server."
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start a new embedding server on the given port."""
|
"""Start a new embedding server on the given port."""
|
||||||
print(f"INFO: Starting embedding server on port {port}...")
|
logger.info(f"Starting embedding server on port {port}...")
|
||||||
|
|
||||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
@@ -241,7 +251,7 @@ class EmbeddingServerManager:
|
|||||||
self._launch_server_process(command, port)
|
self._launch_server_process(command, port)
|
||||||
return self._wait_for_server_ready(port)
|
return self._wait_for_server_ready(port)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ ERROR: Failed to start embedding server: {e}")
|
logger.error(f"Failed to start embedding server: {e}")
|
||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
def _build_server_command(
|
def _build_server_command(
|
||||||
@@ -268,20 +278,18 @@ class EmbeddingServerManager:
|
|||||||
def _launch_server_process(self, command: list, port: int) -> None:
|
def _launch_server_process(self, command: list, port: int) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
print(f"INFO: Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
|
|
||||||
|
# Let server output go directly to console
|
||||||
|
# The server will respect LEANN_LOG_LEVEL environment variable
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=subprocess.PIPE,
|
stdout=None, # Direct to console
|
||||||
stderr=subprocess.STDOUT,
|
stderr=None, # Direct to console
|
||||||
text=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
bufsize=1,
|
|
||||||
universal_newlines=True,
|
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
print(f"INFO: Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
if not self._atexit_registered:
|
if not self._atexit_registered:
|
||||||
@@ -294,49 +302,19 @@ class EmbeddingServerManager:
|
|||||||
max_wait, wait_interval = 120, 0.5
|
max_wait, wait_interval = 120, 0.5
|
||||||
for _ in range(int(max_wait / wait_interval)):
|
for _ in range(int(max_wait / wait_interval)):
|
||||||
if _check_port(port):
|
if _check_port(port):
|
||||||
print("✅ Embedding server is ready!")
|
logger.info("Embedding server is ready!")
|
||||||
threading.Thread(target=self._log_monitor, daemon=True).start()
|
|
||||||
return True, port
|
return True, port
|
||||||
|
|
||||||
if self.server_process.poll() is not None:
|
if self.server_process and self.server_process.poll() is not None:
|
||||||
print("❌ ERROR: Server terminated during startup.")
|
logger.error("Server terminated during startup.")
|
||||||
self._print_recent_output()
|
|
||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
time.sleep(wait_interval)
|
time.sleep(wait_interval)
|
||||||
|
|
||||||
print(f"❌ ERROR: Server failed to start within {max_wait} seconds.")
|
logger.error(f"Server failed to start within {max_wait} seconds.")
|
||||||
self.stop_server()
|
self.stop_server()
|
||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
def _print_recent_output(self):
|
|
||||||
"""Print any recent output from the server process."""
|
|
||||||
if not self.server_process or not self.server_process.stdout:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if select.select([self.server_process.stdout], [], [], 0)[0]:
|
|
||||||
output = self.server_process.stdout.read()
|
|
||||||
if output:
|
|
||||||
print(f"[{self.backend_module_name} OUTPUT]: {output}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading server output: {e}")
|
|
||||||
|
|
||||||
def _log_monitor(self):
|
|
||||||
"""Monitors and prints the server's stdout and stderr."""
|
|
||||||
if not self.server_process:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
if self.server_process.stdout:
|
|
||||||
while True:
|
|
||||||
line = self.server_process.stdout.readline()
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
print(
|
|
||||||
f"[{self.backend_module_name} LOG]: {line.strip()}", flush=True
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Log monitor error: {e}")
|
|
||||||
|
|
||||||
def stop_server(self):
|
def stop_server(self):
|
||||||
"""Stops the embedding server process if it's running."""
|
"""Stops the embedding server process if it's running."""
|
||||||
if not self.server_process:
|
if not self.server_process:
|
||||||
@@ -347,17 +325,17 @@ class EmbeddingServerManager:
|
|||||||
self.server_process = None
|
self.server_process = None
|
||||||
return
|
return
|
||||||
|
|
||||||
print(
|
logger.info(
|
||||||
f"INFO: Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
f"Terminating server process (PID: {self.server_process.pid}) for backend {self.backend_module_name}..."
|
||||||
)
|
)
|
||||||
self.server_process.terminate()
|
self.server_process.terminate()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.server_process.wait(timeout=5)
|
self.server_process.wait(timeout=5)
|
||||||
print(f"INFO: Server process {self.server_process.pid} terminated.")
|
logger.info(f"Server process {self.server_process.pid} terminated.")
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
print(
|
logger.warning(
|
||||||
f"WARNING: Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
f"Server process {self.server_process.pid} did not terminate gracefully, killing it."
|
||||||
)
|
)
|
||||||
self.server_process.kill()
|
self.server_process.kill()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user