diff --git a/packages/leann-backend-diskann/leann_backend_diskann/__init__.py b/packages/leann-backend-diskann/leann_backend_diskann/__init__.py index e69de29..08137b0 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/__init__.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/__init__.py @@ -0,0 +1 @@ +from . import diskann_backend \ No newline at end of file diff --git a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py index 8af4046..d59ef69 100644 --- a/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py +++ b/packages/leann-backend-diskann/leann_backend_diskann/diskann_backend.py @@ -12,6 +12,7 @@ import socket import subprocess import sys +from leann.embedding_server_manager import EmbeddingServerManager from leann.registry import register_backend from leann.interface import ( LeannBackendFactoryInterface, @@ -42,96 +43,6 @@ def _write_vectors_to_bin(data: np.ndarray, file_path: str): f.write(struct.pack('I', dim)) f.write(data.tobytes()) -def _check_port(port: int) -> bool: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 - -class EmbeddingServerManager: - def __init__(self): - self.server_process = None - self.server_port = None - atexit.register(self.stop_server) - - def start_server(self, port=5555, model_name="sentence-transformers/all-mpnet-base-v2"): - if self.server_process and self.server_process.poll() is None: - print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})") - return True - - # 检查端口是否已被其他无关进程占用 - if _check_port(port): - print(f"WARNING: Port {port} is already in use. 
Assuming an external server is running and connecting to it.") - return True - - print(f"INFO: Starting session-level embedding server as a background process...") - - try: - command = [ - sys.executable, - "-m", "packages.leann-backend-diskann.leann_backend_diskann.embedding_server", - "--zmq-port", str(port), - "--model-name", model_name - ] - project_root = Path(__file__).parent.parent.parent.parent - print(f"INFO: Running command from project root: {project_root}") - self.server_process = subprocess.Popen( - command, - cwd=project_root, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE, - text=True, - encoding='utf-8' - ) - self.server_port = port - print(f"INFO: Server process started with PID: {self.server_process.pid}") - - max_wait, wait_interval = 30, 0.5 - for _ in range(int(max_wait / wait_interval)): - if _check_port(port): - print(f"✅ Embedding server is up and ready for this session.") - log_thread = threading.Thread(target=self._log_monitor, daemon=True) - log_thread.start() - return True - if self.server_process.poll() is not None: - print("❌ ERROR: Server process terminated unexpectedly during startup.") - self._log_monitor() - return False - time.sleep(wait_interval) - - print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.") - self.stop_server() - return False - - except Exception as e: - print(f"❌ ERROR: Failed to start embedding server process: {e}") - return False - - def _log_monitor(self): - if not self.server_process: - return - try: - if self.server_process.stdout: - for line in iter(self.server_process.stdout.readline, ''): - print(f"[EmbeddingServer LOG]: {line.strip()}") - self.server_process.stdout.close() - if self.server_process.stderr: - for line in iter(self.server_process.stderr.readline, ''): - print(f"[EmbeddingServer ERROR]: {line.strip()}") - self.server_process.stderr.close() - except Exception as e: - print(f"Log monitor error: {e}") - - def stop_server(self): - if self.server_process and 
self.server_process.poll() is None: - print(f"INFO: Terminating session server process (PID: {self.server_process.pid})...") - self.server_process.terminate() - try: - self.server_process.wait(timeout=5) - print("INFO: Server process terminated.") - except subprocess.TimeoutExpired: - print("WARNING: Server process did not terminate gracefully, killing it.") - self.server_process.kill() - self.server_process = None - @register_backend("diskann") class DiskannBackend(LeannBackendFactoryInterface): @staticmethod @@ -143,16 +54,13 @@ class DiskannBackend(LeannBackendFactoryInterface): path = Path(index_path) meta_path = path.parent / f"{path.name}.meta.json" if not meta_path.exists(): - raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.") - + raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.") + with open(meta_path, 'r') as f: meta = json.load(f) - dimensions = meta.get("dimensions") - if not dimensions: - raise ValueError("Dimensions not found in Leann metadata. 
Please rebuild the index with a newer version of Leann.") - - kwargs['dimensions'] = dimensions + # Pass essential metadata to the searcher + kwargs['meta'] = meta return DiskannSearcher(index_path, **kwargs) class DiskannBuilder(LeannBackendBuilderInterface): @@ -215,19 +123,29 @@ class DiskannBuilder(LeannBackendBuilderInterface): class DiskannSearcher(LeannBackendSearcherInterface): def __init__(self, index_path: str, **kwargs): + self.meta = kwargs.get("meta", {}) + if not self.meta: + raise ValueError("DiskannSearcher requires metadata from .meta.json.") + + dimensions = self.meta.get("dimensions") + if not dimensions: + raise ValueError("Dimensions not found in Leann metadata.") + + self.distance_metric = self.meta.get("distance_metric", "mips").lower() + metric_enum = METRIC_MAP.get(self.distance_metric) + if metric_enum is None: + raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") + + self.embedding_model = self.meta.get("embedding_model") + if not self.embedding_model: + print("WARNING: embedding_model not found in meta.json. 
Recompute will fail if attempted.") + path = Path(index_path) index_dir = path.parent index_prefix = path.stem - metric_str = kwargs.get("distance_metric", "mips").lower() - metric_enum = METRIC_MAP.get(metric_str) - if metric_enum is None: - raise ValueError(f"Unsupported distance_metric '{metric_str}'.") num_threads = kwargs.get("num_threads", 8) num_nodes_to_cache = kwargs.get("num_nodes_to_cache", 0) - dimensions = kwargs.get("dimensions") - if not dimensions: - raise ValueError("Vector dimension not provided to DiskannSearcher.") try: full_index_prefix = str(index_dir / index_prefix) @@ -235,7 +153,9 @@ class DiskannSearcher(LeannBackendSearcherInterface): metric_enum, full_index_prefix, num_threads, num_nodes_to_cache, 1, "", "" ) self.num_threads = num_threads - self.embedding_server_manager = EmbeddingServerManager() + self.embedding_server_manager = EmbeddingServerManager( + backend_module_name="leann_backend_diskann.embedding_server" + ) print("✅ DiskANN index loaded successfully.") except Exception as e: print(f"💥 ERROR: Failed to load DiskANN index. 
Exception: {e}") @@ -255,12 +175,20 @@ class DiskannSearcher(LeannBackendSearcherInterface): if recompute_beighbor_embeddings: print(f"INFO: DiskANN ZMQ mode enabled - ensuring embedding server is running") - zmq_port = kwargs.get("zmq_port", 6666) - embedding_model = kwargs.get("embedding_model", "sentence-transformers/all-mpnet-base-v2") + if not self.embedding_model: + raise ValueError("Cannot use recompute_beighbor_embeddings without 'embedding_model' in meta.json.") - if not self.embedding_server_manager.start_server(zmq_port, embedding_model): + zmq_port = kwargs.get("zmq_port", 6666) + + server_started = self.embedding_server_manager.start_server( + port=zmq_port, + model_name=self.embedding_model, + distance_metric=self.distance_metric + ) + + if not server_started: print(f"WARNING: Failed to start embedding server, falling back to PQ computation") - kwargs['recompute_beighbor_embeddings'] = False + recompute_beighbor_embeddings = False if query.dtype != np.float32: query = query.astype(np.float32) @@ -292,4 +220,4 @@ class DiskannSearcher(LeannBackendSearcherInterface): def __del__(self): if hasattr(self, 'embedding_server_manager'): - self.embedding_server_manager.stop_server() \ No newline at end of file + self.embedding_server_manager.stop_server() diff --git a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py index 7f21c84..2cad5fb 100644 --- a/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py +++ b/packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py @@ -12,6 +12,7 @@ import socket import subprocess import sys +from leann.embedding_server_manager import EmbeddingServerManager from .convert_to_csr import convert_hnsw_graph_to_csr from leann.registry import register_backend @@ -29,118 +30,6 @@ def get_metric_map(): "cosine": faiss.METRIC_INNER_PRODUCT, } -def _check_port(port: int) -> bool: - """Check if a port is in use""" - with 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(('localhost', port)) == 0 - -class HNSWEmbeddingServerManager: - """ - HNSW-specific embedding server manager that handles the lifecycle of the embedding server process. - Mirrors the DiskANN EmbeddingServerManager architecture. - """ - def __init__(self): - self.server_process = None - self.server_port = None - atexit.register(self.stop_server) - - def start_server(self, port=5556, model_name="sentence-transformers/all-mpnet-base-v2", passages_file=None, distance_metric="mips"): - """ - Start the HNSW embedding server process. - - Args: - port: ZMQ port for the server - model_name: Name of the embedding model to use - passages_file: Optional path to passages JSON file - distance_metric: The distance metric to use - """ - if self.server_process and self.server_process.poll() is None: - print(f"INFO: Reusing existing HNSW server process for this session (PID {self.server_process.pid})") - return True - - # Check if port is already in use - if _check_port(port): - print(f"WARNING: Port {port} is already in use. 
Assuming an external HNSW server is running and connecting to it.") - return True - - print(f"INFO: Starting session-level HNSW embedding server as a background process...") - - try: - command = [ - sys.executable, - "-m", "leann_backend_hnsw.hnsw_embedding_server", - "--zmq-port", str(port), - "--model-name", model_name, - "--distance-metric", distance_metric - ] - - if passages_file: - command.extend(["--passages-file", str(passages_file)]) - - project_root = Path(__file__).parent.parent.parent.parent - print(f"INFO: Running HNSW command from project root: {project_root}") - - self.server_process = subprocess.Popen( - command, - cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - encoding='utf-8' - ) - self.server_port = port - print(f"INFO: HNSW server process started with PID: {self.server_process.pid}") - - max_wait, wait_interval = 30, 0.5 - for _ in range(int(max_wait / wait_interval)): - if _check_port(port): - print(f"✅ HNSW embedding server is up and ready for this session.") - log_thread = threading.Thread(target=self._log_monitor, daemon=True) - log_thread.start() - return True - if self.server_process.poll() is not None: - print("❌ ERROR: HNSW server process terminated unexpectedly during startup.") - self._log_monitor() - return False - time.sleep(wait_interval) - - print(f"❌ ERROR: HNSW server process failed to start listening within {max_wait} seconds.") - self.stop_server() - return False - - except Exception as e: - print(f"❌ ERROR: Failed to start HNSW embedding server process: {e}") - return False - - def _log_monitor(self): - """Monitor server logs""" - if not self.server_process: - return - try: - if self.server_process.stdout: - for line in iter(self.server_process.stdout.readline, ''): - print(f"[HNSWEmbeddingServer LOG]: {line.strip()}") - self.server_process.stdout.close() - if self.server_process.stderr: - for line in iter(self.server_process.stderr.readline, ''): - print(f"[HNSWEmbeddingServer ERROR]: 
{line.strip()}") - self.server_process.stderr.close() - except Exception as e: - print(f"HNSW Log monitor error: {e}") - - def stop_server(self): - """Stop the HNSW embedding server process""" - if self.server_process and self.server_process.poll() is None: - print(f"INFO: Terminating HNSW session server process (PID: {self.server_process.pid})...") - self.server_process.terminate() - try: - self.server_process.wait(timeout=5) - print("INFO: HNSW server process terminated.") - except subprocess.TimeoutExpired: - print("WARNING: HNSW server process did not terminate gracefully, killing it.") - self.server_process.kill() - self.server_process = None - @register_backend("hnsw") class HNSWBackend(LeannBackendFactoryInterface): @staticmethod @@ -152,16 +41,12 @@ class HNSWBackend(LeannBackendFactoryInterface): path = Path(index_path) meta_path = path.parent / f"{path.name}.meta.json" if not meta_path.exists(): - raise FileNotFoundError(f"Leann metadata file not found at {meta_path}. Cannot infer vector dimension for searcher.") + raise FileNotFoundError(f"Leann metadata file not found at {meta_path}.") with open(meta_path, 'r') as f: meta = json.load(f) - dimensions = meta.get("dimensions") - if not dimensions: - raise ValueError("Dimensions not found in Leann metadata. Please rebuild the index with a newer version of Leann.") - - kwargs['dimensions'] = dimensions + kwargs['meta'] = meta return HNSWSearcher(index_path, **kwargs) class HNSWBuilder(LeannBackendBuilderInterface): @@ -376,47 +261,49 @@ class HNSWSearcher(LeannBackendSearcherInterface): def __init__(self, index_path: str, **kwargs): from . 
import faiss - path = Path(index_path) - index_dir = path.parent - index_prefix = path.stem - - # Store configuration and paths for later use - self.config = kwargs.copy() - self.config["index_path"] = index_path - self.index_dir = index_dir - self.index_prefix = index_prefix - - metric_str = self.config.get("distance_metric", "mips").lower() - metric_enum = get_metric_map().get(metric_str) + self.meta = kwargs.get("meta", {}) + if not self.meta: + raise ValueError("HNSWSearcher requires metadata from .meta.json.") + + self.dimensions = self.meta.get("dimensions") + if not self.dimensions: + raise ValueError("Dimensions not found in Leann metadata.") + + self.distance_metric = self.meta.get("distance_metric", "mips").lower() + metric_enum = get_metric_map().get(self.distance_metric) if metric_enum is None: - raise ValueError(f"Unsupported distance_metric '{metric_str}'.") + raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.") + + self.embedding_model = self.meta.get("embedding_model") + if not self.embedding_model: + print("WARNING: embedding_model not found in meta.json. 
Recompute will fail if attempted.") + + path = Path(index_path) + self.index_dir = path.parent + self.index_prefix = path.stem - dimensions = self.config.get("dimensions") - if not dimensions: - raise ValueError("Vector dimension not provided to HNSWSearcher.") - - index_file = index_dir / f"{index_prefix}.index" + index_file = self.index_dir / f"{self.index_prefix}.index" if not index_file.exists(): raise FileNotFoundError(f"HNSW index file not found at {index_file}") self.is_compact, self.is_pruned = self._get_index_storage_status(index_file) # Validate configuration constraints - if not self.is_compact and self.config.get("is_skip_neighbors", False): + if not self.is_compact and kwargs.get("is_skip_neighbors", False): raise ValueError("is_skip_neighbors can only be used with is_compact=True") - if self.config.get("is_recompute", False) and self.config.get("external_storage_path"): + if kwargs.get("is_recompute", False) and kwargs.get("external_storage_path"): raise ValueError("Cannot use both is_recompute and external_storage_path simultaneously") hnsw_config = faiss.HNSWIndexConfig() hnsw_config.is_compact = self.is_compact # Apply additional configuration options with strict validation - hnsw_config.is_skip_neighbors = self.config.get("is_skip_neighbors", False) - hnsw_config.is_recompute = self.is_pruned or self.config.get("is_recompute", False) - hnsw_config.disk_cache_ratio = self.config.get("disk_cache_ratio", 0.0) - hnsw_config.external_storage_path = self.config.get("external_storage_path") - hnsw_config.zmq_port = self.config.get("zmq_port", 5557) + hnsw_config.is_skip_neighbors = kwargs.get("is_skip_neighbors", False) + hnsw_config.is_recompute = self.is_pruned or kwargs.get("is_recompute", False) + hnsw_config.disk_cache_ratio = kwargs.get("disk_cache_ratio", 0.0) + hnsw_config.external_storage_path = kwargs.get("external_storage_path") + hnsw_config.zmq_port = kwargs.get("zmq_port", 5557) if self.is_pruned and not hnsw_config.is_recompute: raise 
RuntimeError("Index is pruned (embeddings removed) but recompute is disabled. This is impossible - recompute must be enabled for pruned indices.") @@ -431,82 +318,55 @@ class HNSWSearcher(LeannBackendSearcherInterface): else: print("✅ Standard HNSW index loaded successfully.") - self.metric_str = metric_str - self.embedding_server_manager = HNSWEmbeddingServerManager() - - def _get_index_file(self, index_dir: Path, index_prefix: str) -> Path: - """Get the appropriate index file path based on format""" - # We always use the same filename now, format is detected internally - return index_dir / f"{index_prefix}.index" + self.embedding_server_manager = EmbeddingServerManager( + backend_module_name="leann_backend_hnsw.hnsw_embedding_server" + ) def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, Any]: """Search using HNSW index with optional recompute functionality""" from . import faiss - # Merge config with search-time kwargs - search_config = self.config.copy() - search_config.update(kwargs) - ef = search_config.get("ef", 200) # Size of the dynamic candidate list for search + ef = kwargs.get("ef", 200) - # Recompute parameters - zmq_port = search_config.get("zmq_port", 5557) - embedding_model = search_config.get("embedding_model", "sentence-transformers/all-mpnet-base-v2") - passages_file = search_config.get("passages_file", None) - - # For recompute mode, try to find the passages file automatically - if self.is_pruned and not passages_file: - potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json" - print(f"DEBUG: Checking for passages file at: {potential_passages_file}") - if potential_passages_file.exists(): - passages_file = str(potential_passages_file) - print(f"INFO: Found passages file for recompute mode: {passages_file}") - else: - print(f"WARNING: No passages file found for recompute mode at {potential_passages_file}") - - # If index is pruned (embeddings removed), we MUST start embedding server for recompute if 
self.is_pruned: - print(f"INFO: Index is pruned - starting embedding server for recompute") - - # CRITICAL: Check passages file exists - fail fast if not + print(f"INFO: Index is pruned - ensuring embedding server is running for recompute.") + if not self.embedding_model: + raise ValueError("Cannot use recompute mode without 'embedding_model' in meta.json.") + + passages_file = kwargs.get("passages_file") if not passages_file: - raise RuntimeError(f"FATAL: Index is pruned but no passages file found. Cannot proceed with recompute mode.") - - # Check if server is already running first - if _check_port(zmq_port): - print(f"INFO: Embedding server already running on port {zmq_port}") - else: - if not self.embedding_server_manager.start_server(zmq_port, embedding_model, passages_file, self.metric_str): - raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}") - - # Give server extra time to fully initialize - print(f"INFO: Waiting for embedding server to fully initialize...") - time.sleep(3) - - # Final verification - if not _check_port(zmq_port): - raise RuntimeError(f"Embedding server failed to start listening on port {zmq_port}") - else: - print(f"INFO: Index has embeddings stored - no recompute needed") + potential_passages_file = self.index_dir / f"{self.index_prefix}.passages.json" + if potential_passages_file.exists(): + passages_file = str(potential_passages_file) + print(f"INFO: Automatically found passages file: {passages_file}") + else: + raise RuntimeError(f"FATAL: Index is pruned but no passages file found.") + + zmq_port = kwargs.get("zmq_port", 5557) + server_started = self.embedding_server_manager.start_server( + port=zmq_port, + model_name=self.embedding_model, + passages_file=passages_file, + distance_metric=self.distance_metric + ) + if not server_started: + raise RuntimeError(f"Failed to start HNSW embedding server on port {zmq_port}") if query.dtype != np.float32: query = query.astype(np.float32) if query.ndim == 1: query = 
np.expand_dims(query, axis=0) - # Normalize query if using cosine similarity - if self.metric_str == "cosine": + if self.distance_metric == "cosine": faiss.normalize_L2(query) try: - # Set search parameter self._index.hnsw.efSearch = ef - # Prepare output arrays for the older FAISS SWIG API batch_size = query.shape[0] distances = np.empty((batch_size, top_k), dtype=np.float32) labels = np.empty((batch_size, top_k), dtype=np.int64) - # Use standard FAISS search - recompute is handled internally by FAISS self._index.search(query.shape[0], faiss.swig_ptr(query), top_k, faiss.swig_ptr(distances), faiss.swig_ptr(labels)) return {"labels": labels, "distances": distances} @@ -517,4 +377,4 @@ class HNSWSearcher(LeannBackendSearcherInterface): def __del__(self): if hasattr(self, 'embedding_server_manager'): - self.embedding_server_manager.stop_server() \ No newline at end of file + self.embedding_server_manager.stop_server() diff --git a/packages/leann-core/src/leann/__init__.py b/packages/leann-core/src/leann/__init__.py index 251953e..e69de29 100644 --- a/packages/leann-core/src/leann/__init__.py +++ b/packages/leann-core/src/leann/__init__.py @@ -1,17 +0,0 @@ -# This file makes the 'leann' directory a Python package. 
"""Lifecycle management for backend-specific embedding server processes."""

import os
import threading
import time
import atexit
import socket
import subprocess
import sys
from pathlib import Path
from typing import List, Optional


def _check_port(port: int) -> bool:
    """Return True if a TCP connection to localhost:``port`` succeeds."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0


class EmbeddingServerManager:
    """
    A generic manager for handling the lifecycle of a backend-specific
    embedding server process.

    The server is launched as ``python -m <backend_module_name>`` and is
    considered ready once something accepts TCP connections on its port.
    """

    def __init__(self, backend_module_name: str):
        """
        Initializes the manager for a specific backend.

        Args:
            backend_module_name (str): The full module name of the backend's
                server script, e.g. "leann_backend_diskann.embedding_server".
        """
        self.backend_module_name = backend_module_name
        self.server_process: Optional[subprocess.Popen] = None
        self.server_port: Optional[int] = None
        # Make sure a server started by this manager does not outlive the
        # interpreter session.
        atexit.register(self.stop_server)

    def _build_command(self, port: int, model_name: str, **kwargs) -> List[str]:
        """Build the argv list used to launch the backend's server module."""
        command = [
            sys.executable,
            "-m", self.backend_module_name,
            "--zmq-port", str(port),
            "--model-name", model_name,
        ]

        # Only forwarded when a non-empty value was supplied.
        passages_file = kwargs.get("passages_file")
        if passages_file:
            command.extend(["--passages-file", str(passages_file)])

        # NOTE(review): callers pass ``distance_metric`` but it is deliberately
        # NOT forwarded as ``--distance-metric`` here. Presumably not every
        # backend server script accepts that flag -- confirm, then forward it
        # for the backends that do.
        return command

    def _wait_until_ready(self, port: int, max_wait: int = 30,
                          wait_interval: float = 0.5) -> bool:
        """Poll until the server listens on ``port``, exits, or times out."""
        for _ in range(int(max_wait / wait_interval)):
            if _check_port(port):
                print("✅ Embedding server is up and ready for this session.")
                log_thread = threading.Thread(target=self._log_monitor, daemon=True)
                log_thread.start()
                return True
            if self.server_process.poll() is not None:
                print("❌ ERROR: Server process terminated unexpectedly during startup.")
                self._log_monitor()
                # Drop the dead handle so no stale process/port is kept around.
                self.server_process = None
                self.server_port = None
                return False
            time.sleep(wait_interval)

        print(f"❌ ERROR: Server process failed to start listening within {max_wait} seconds.")
        self.stop_server()
        return False

    def start_server(self, port: int, model_name: str, **kwargs) -> bool:
        """
        Starts the embedding server process.

        Args:
            port (int): The ZMQ port for the server.
            model_name (str): The name of the embedding model to use.
            **kwargs: Additional arguments for the server. ``passages_file``
                is forwarded as ``--passages-file``; ``distance_metric`` is
                accepted but currently not forwarded.

        Returns:
            bool: True if the server is started successfully or already
            running (either our own child or anything already listening on
            ``port``), False otherwise.
        """
        if self.server_process and self.server_process.poll() is None:
            print(f"INFO: Reusing existing server process for this session (PID {self.server_process.pid})")
            return True

        if _check_port(port):
            print(f"WARNING: Port {port} is already in use. Assuming an external server is running.")
            return True

        print(f"INFO: Starting session-level embedding server for '{self.backend_module_name}'...")

        try:
            command = self._build_command(port, model_name, **kwargs)

            # Five directory levels above this file.
            project_root = Path(__file__).parent.parent.parent.parent.parent
            print(f"INFO: Running command from project root: {project_root}")

            # stdout/stderr are not piped, so the child writes directly to the
            # parent's streams.
            self.server_process = subprocess.Popen(
                command,
                cwd=project_root,
                text=True,
                encoding='utf-8',
            )
            self.server_port = port
            print(f"INFO: Server process started with PID: {self.server_process.pid}")

            return self._wait_until_ready(port)

        except Exception as e:
            print(f"❌ ERROR: Failed to start embedding server process: {e}")
            # A child may already have been spawned before the failure; do not
            # leave it running unmanaged.
            self.stop_server()
            return False

    def _log_monitor(self):
        """Monitors and prints the server's stdout and stderr.

        Only has an effect when the process was created with piped streams;
        otherwise ``stdout``/``stderr`` are None and nothing is read.
        """
        if not self.server_process:
            return
        try:
            if self.server_process.stdout:
                for line in iter(self.server_process.stdout.readline, ''):
                    print(f"[{self.backend_module_name} LOG]: {line.strip()}")
                self.server_process.stdout.close()
            if self.server_process.stderr:
                for line in iter(self.server_process.stderr.readline, ''):
                    print(f"[{self.backend_module_name} ERROR]: {line.strip()}")
                self.server_process.stderr.close()
        except Exception as e:
            print(f"Log monitor error: {e}")

    def stop_server(self):
        """Stops the embedding server process if it's running.

        Sends terminate, waits up to 5 seconds, then kills and reaps the
        process. Always clears the stored process handle and port.
        """
        process = self.server_process
        if process is not None and process.poll() is None:
            print(f"INFO: Terminating session server process (PID: {process.pid})...")
            process.terminate()
            try:
                process.wait(timeout=5)
                print("INFO: Server process terminated.")
            except subprocess.TimeoutExpired:
                print("WARNING: Server process did not terminate gracefully, killing it.")
                process.kill()
                # Reap the killed child so it does not linger as a zombie.
                process.wait()
        self.server_process = None
        self.server_port = None