diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index 832093d..310fc85 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -96,7 +96,7 @@ def compute_embeddings_sentence_transformers( backend_module_name="leann_backend_hnsw.hnsw_embedding_server" ) - server_started = server_manager.start_server( + server_started, actual_port = server_manager.start_server( port=port, model_name=model_name, embedding_mode="sentence-transformers", @@ -104,7 +104,10 @@ def compute_embeddings_sentence_transformers( ) if not server_started: - raise RuntimeError(f"Failed to start embedding server on port {port}") + raise RuntimeError(f"Failed to start embedding server on port {actual_port}") + + # Use the actual port for connection + port = actual_port # Connect to embedding server context = zmq.Context() diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py index 2022262..4071dfc 100644 --- a/packages/leann-core/src/leann/embedding_server_manager.py +++ b/packages/leann-core/src/leann/embedding_server_manager.py @@ -9,6 +9,7 @@ import msgpack from pathlib import Path from typing import Optional import select +import psutil def _check_port(port: int) -> bool: @@ -17,151 +18,131 @@ def _check_port(port: int) -> bool: return s.connect_ex(("localhost", port)) == 0 -def _check_server_meta_path(port: int, expected_meta_path: str) -> bool: +def _check_process_matches_config( + port: int, expected_model: str, expected_passages_file: str = None +) -> bool: """ - Check if the existing server on the port is using the correct meta file. - Returns True if the server has the right meta path, False otherwise. + Check if the process using the port matches our expected model and passages file. + Returns True if matches, False otherwise. """ try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout - socket.connect(f"tcp://localhost:{port}") - - # Send a special control message to query the server's meta path - control_request = ["__QUERY_META_PATH__"] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) - - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the response contains the meta path and if it matches - if isinstance(response, list) and len(response) > 0: - server_meta_path = response[0] - # Normalize paths for comparison - expected_path = Path(expected_meta_path).resolve() - server_path = Path(server_meta_path).resolve() if server_meta_path else None - return server_path == expected_path - + for proc in psutil.process_iter(["pid", "cmdline"]): + if not _is_process_listening_on_port(proc, port): + continue + + cmdline = proc.info["cmdline"] + if not cmdline: + continue + + return _check_cmdline_matches_config(cmdline, port, expected_model, expected_passages_file) + + print(f"DEBUG: No process found listening on port {port}") return False - + except Exception as e: - print(f"WARNING: Could not query server meta path on port {port}: {e}") + print(f"WARNING: Could not check process on port {port}: {e}") return False -def _update_server_meta_path(port: int, new_meta_path: str) -> bool: - """ - Send a control message to update the server's meta path. - Returns True if successful, False otherwise. - """ +def _is_process_listening_on_port(proc, port: int) -> bool: + """Check if a process is listening on the given port.""" try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 5000) # 5 second timeout - socket.connect(f"tcp://localhost:{port}") - - # Send a control message to update the meta path - control_request = ["__UPDATE_META_PATH__", new_meta_path] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) - - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the update was successful - if isinstance(response, list) and len(response) > 0: - return response[0] == "SUCCESS" - + connections = proc.net_connections() + for conn in connections: + if conn.laddr.port == port and conn.status == psutil.CONN_LISTEN: + return True return False - - except Exception as e: - print(f"ERROR: Could not update server meta path on port {port}: {e}") + except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): return False -def _check_server_model(port: int, expected_model: str) -> bool: +def _check_cmdline_matches_config( + cmdline: list, port: int, expected_model: str, expected_passages_file: str = None +) -> bool: + """Check if command line matches our expected configuration.""" + cmdline_str = " ".join(cmdline) + print(f"DEBUG: Found process on port {port}: {cmdline_str}") + + # Check if it's our embedding server + is_embedding_server = any(server_type in cmdline_str for server_type in [ + "embedding_server", + "leann_backend_diskann.embedding_server", + "leann_backend_hnsw.hnsw_embedding_server" + ]) + + if not is_embedding_server: + print(f"DEBUG: Process on port {port} is not our embedding server") + return False + + # Check model name + model_matches = _check_model_in_cmdline(cmdline, expected_model) + + # Check passages file if provided + passages_matches = _check_passages_in_cmdline(cmdline, expected_passages_file) + + result = model_matches and passages_matches + print(f"DEBUG: model_matches: {model_matches}, passages_matches: {passages_matches}, overall: {result}") + return result + + +def _check_model_in_cmdline(cmdline: list, expected_model: str) -> bool: + """Check if the command line contains the expected model.""" + if "--model-name" not in cmdline: + return False + + model_idx = cmdline.index("--model-name") + if model_idx + 1 >= len(cmdline): + return False + + actual_model = cmdline[model_idx + 1] + return actual_model == expected_model + + +def _check_passages_in_cmdline(cmdline: list, expected_passages_file: str = None) -> bool: + """Check if the command line contains the expected passages file.""" + if not expected_passages_file: + return True # No passages file expected + + if "--passages-file" not in cmdline: + return False # Expected but not found + + passages_idx = cmdline.index("--passages-file") + if passages_idx + 1 >= len(cmdline): + return False + + actual_passages = cmdline[passages_idx + 1] + expected_path = Path(expected_passages_file).resolve() + actual_path = Path(actual_passages).resolve() + return actual_path == expected_path + + +def _find_compatible_port_or_next_available( + start_port: int, model_name: str, passages_file: str = None, max_attempts: int = 100 +) -> tuple[int, bool]: """ - Check if the existing server on the port is using the correct embedding model. - Returns True if the server has the right model, False otherwise. + Find a port that either has a compatible server or is available. + Returns (port, is_compatible) where is_compatible indicates if we found a matching server. """ - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 3000) # 3 second timeout - socket.connect(f"tcp://localhost:{port}") + for port in range(start_port, start_port + max_attempts): + if not _check_port(port): + # Port is available + return port, False - # Send a special control message to query the server's model - control_request = ["__QUERY_MODEL__"] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) + # Port is in use, check if it's compatible + if _check_process_matches_config(port, model_name, passages_file): + print(f"✅ Found compatible server on port {port}") + return port, True + else: + print(f"⚠️ Port {port} has incompatible server, trying next port...") - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the response contains the model name and if it matches - if isinstance(response, list) and len(response) > 0: - server_model = response[0] - return server_model == expected_model - - return False - - except Exception as e: - print(f"WARNING: Could not query server model on port {port}: {e}") - return False - - -def _update_server_model(port: int, new_model: str) -> bool: - """ - Send a control message to update the server's embedding model. - Returns True if successful, False otherwise. - """ - try: - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.setsockopt(zmq.RCVTIMEO, 30000) # 30 second timeout for model loading - socket.setsockopt(zmq.SNDTIMEO, 5000) # 5 second timeout for sending - socket.connect(f"tcp://localhost:{port}") - - # Send a control message to update the model - control_request = ["__UPDATE_MODEL__", new_model] - request_bytes = msgpack.packb(control_request) - socket.send(request_bytes) - - # Wait for response - response_bytes = socket.recv() - response = msgpack.unpackb(response_bytes) - - socket.close() - context.term() - - # Check if the update was successful - if isinstance(response, list) and len(response) > 0: - return response[0] == "SUCCESS" - - return False - - except Exception as e: - print(f"ERROR: Could not update server model on port {port}: {e}") - return False + raise RuntimeError( + f"Could not find compatible or available port in range {start_port}-{start_port + max_attempts}" + ) class EmbeddingServerManager: """ - A generic manager for handling the lifecycle of a backend-specific embedding server process. + A simplified manager for embedding server processes that avoids complex update mechanisms. """ def __init__(self, backend_module_name: str): @@ -177,208 +158,132 @@ class EmbeddingServerManager: self.server_port: Optional[int] = None atexit.register(self.stop_server) - def start_server(self, port: int, model_name: str, embedding_mode: str = "sentence-transformers", **kwargs) -> bool: + def start_server( + self, + port: int, + model_name: str, + embedding_mode: str = "sentence-transformers", + **kwargs, + ) -> tuple[bool, int]: """ Starts the embedding server process. Args: - port (int): The ZMQ port for the server. + port (int): The preferred ZMQ port for the server. model_name (str): The name of the embedding model to use. - **kwargs: Additional arguments for the server (e.g., passages_file, distance_metric, enable_warmup). + **kwargs: Additional arguments for the server. Returns: - bool: True if the server is started successfully or already running, False otherwise. + tuple[bool, int]: (success, actual_port_used) """ - if self.server_process and self.server_process.poll() is None: - # Even if we have a running process, check if model/meta path match - if self.server_port is not None: - port_in_use = _check_port(self.server_port) - if port_in_use: - print( - f"INFO: Checking compatibility of existing server process (PID {self.server_process.pid})" - ) + passages_file = kwargs.get("passages_file") - # Check model compatibility - model_matches = _check_server_model(self.server_port, model_name) - if model_matches: - print( - f"✅ Existing server already using correct model: {model_name}" - ) - - # Still check meta path if provided - passages_file = kwargs.get("passages_file") - if passages_file and str(passages_file).endswith( - ".meta.json" - ): - meta_matches = _check_server_meta_path( - self.server_port, str(passages_file) - ) - if not meta_matches: - print("⚠️ Updating meta path to: {passages_file}") - _update_server_meta_path( - self.server_port, str(passages_file) - ) - - return True - else: - print( - f"⚠️ Existing server has different model. Attempting to update to: {model_name}" - ) - if not _update_server_model(self.server_port, model_name): - print( - "❌ Failed to update existing server model. Restarting server..." - ) - self.stop_server() - # Continue to start new server below - else: - print( - f"✅ Successfully updated existing server model to: {model_name}" - ) - - # Also check meta path if provided - passages_file = kwargs.get("passages_file") - if passages_file and str(passages_file).endswith( - ".meta.json" - ): - meta_matches = _check_server_meta_path( - self.server_port, str(passages_file) - ) - if not meta_matches: - print("⚠️ Updating meta path to: {passages_file}") - _update_server_meta_path( - self.server_port, str(passages_file) - ) - - return True - else: - # Server process exists but port not responding - restart - print("⚠️ Server process exists but not responding. Restarting...") - self.stop_server() - # Continue to start new server below - else: - # No port stored - restart - print("⚠️ No port information stored. Restarting server...") - self.stop_server() - # Continue to start new server below - - if _check_port(port): - # Port is in use, check if it's using the correct meta file and model - passages_file = kwargs.get("passages_file") - - print(f"INFO: Port {port} is in use. Checking server compatibility...") - - # Check model compatibility first - model_matches = _check_server_model(port, model_name) - if model_matches: - print( - f"✅ Existing server on port {port} is using correct model: {model_name}" - ) - else: - print( - f"⚠️ Existing server on port {port} has different model. Attempting to update to: {model_name}" - ) - if not _update_server_model(port, model_name): - raise RuntimeError( - f"❌ Failed to update server model to {model_name}. Consider using a different port." - ) - print(f"✅ Successfully updated server model to: {model_name}") - - # Check meta path compatibility if provided - if passages_file and str(passages_file).endswith(".meta.json"): - meta_matches = _check_server_meta_path(port, str(passages_file)) - if not meta_matches: - print( - f"⚠️ Existing server on port {port} has different meta path. Attempting to update..." - ) - if not _update_server_meta_path(port, str(passages_file)): - raise RuntimeError( - "❌ Failed to update server meta path. This may cause data synchronization issues." - ) - print( - f"✅ Successfully updated server meta path to: {passages_file}" - ) - else: - print( - f"✅ Existing server on port {port} is using correct meta path: {passages_file}" - ) - - print(f"✅ Server on port {port} is compatible and ready to use.") - return True - - print( - f"INFO: Starting session-level embedding server for '{self.backend_module_name}'..." - ) + # Check if we have a compatible running server + if self._has_compatible_running_server(model_name, passages_file): + return True, self.server_port + # Find available port (compatible or free) try: - command = [ - sys.executable, - "-m", - self.backend_module_name, - "--zmq-port", - str(port), - "--model-name", - model_name, - ] - - # Add extra arguments for specific backends - if "passages_file" in kwargs and kwargs["passages_file"]: - command.extend(["--passages-file", str(kwargs["passages_file"])]) - # if "distance_metric" in kwargs and kwargs["distance_metric"]: - # command.extend(["--distance-metric", kwargs["distance_metric"]]) - if embedding_mode != "sentence-transformers": - command.extend(["--embedding-mode", embedding_mode]) - if "enable_warmup" in kwargs and not kwargs["enable_warmup"]: - command.extend(["--disable-warmup"]) - - project_root = Path(__file__).parent.parent.parent.parent.parent - print(f"INFO: Running command from project root: {project_root}") - print(f"INFO: Command: {' '.join(command)}") # Debug: show actual command - - self.server_process = subprocess.Popen( - command, - cwd=project_root, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, # Merge stderr into stdout for easier monitoring - text=True, - encoding="utf-8", - bufsize=1, # Line buffered - universal_newlines=True, + actual_port, is_compatible = _find_compatible_port_or_next_available( + port, model_name, passages_file ) - self.server_port = port - print(f"INFO: Server process started with PID: {self.server_process.pid}") + except RuntimeError as e: + print(f"❌ {e}") + return False, port - max_wait, wait_interval = 120, 0.5 - for _ in range(int(max_wait / wait_interval)): - if _check_port(port): - print("✅ Embedding server is up and ready for this session.") - log_thread = threading.Thread(target=self._log_monitor, daemon=True) - log_thread.start() - return True - if self.server_process.poll() is not None: - print( - "❌ ERROR: Server process terminated unexpectedly during startup." - ) - self._print_recent_output() - return False - time.sleep(wait_interval) + if is_compatible: + print(f"✅ Using existing compatible server on port {actual_port}") + self.server_port = actual_port + self.server_process = None # We don't own this process + return True, actual_port - print( - f"❌ ERROR: Server process failed to start listening within {max_wait} seconds." - ) - self.stop_server() + if actual_port != port: + print(f"⚠️ Using port {actual_port} instead of {port}") + + # Start new server + return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) + + def _has_compatible_running_server(self, model_name: str, passages_file: str) -> bool: + """Check if we have a compatible running server.""" + if not (self.server_process and self.server_process.poll() is None and self.server_port): return False + if _check_process_matches_config(self.server_port, model_name, passages_file): + print(f"✅ Existing server process (PID {self.server_process.pid}) is compatible") + return True + + print("⚠️ Existing server process is incompatible. Stopping it...") + self.stop_server() + return False + + def _start_new_server(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> tuple[bool, int]: + """Start a new embedding server on the given port.""" + print(f"INFO: Starting embedding server on port {port}...") + + command = self._build_server_command(port, model_name, embedding_mode, **kwargs) + + try: + self._launch_server_process(command, port) + return self._wait_for_server_ready(port) except Exception as e: - print(f"❌ ERROR: Failed to start embedding server process: {e}") - return False + print(f"❌ ERROR: Failed to start embedding server: {e}") + return False, port + + def _build_server_command(self, port: int, model_name: str, embedding_mode: str, **kwargs) -> list: + """Build the command to start the embedding server.""" + command = [ + sys.executable, "-m", self.backend_module_name, + "--zmq-port", str(port), + "--model-name", model_name, + ] + + if kwargs.get("passages_file"): + command.extend(["--passages-file", str(kwargs["passages_file"])]) + if embedding_mode != "sentence-transformers": + command.extend(["--embedding-mode", embedding_mode]) + if kwargs.get("enable_warmup") is False: + command.extend(["--disable-warmup"]) + + return command + + def _launch_server_process(self, command: list, port: int) -> None: + """Launch the server process.""" + project_root = Path(__file__).parent.parent.parent.parent.parent + print(f"INFO: Command: {' '.join(command)}") + + self.server_process = subprocess.Popen( + command, cwd=project_root, + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, encoding="utf-8", bufsize=1, universal_newlines=True, + ) + self.server_port = port + print(f"INFO: Server process started with PID: {self.server_process.pid}") + + def _wait_for_server_ready(self, port: int) -> tuple[bool, int]: + """Wait for the server to be ready.""" + max_wait, wait_interval = 120, 0.5 + for _ in range(int(max_wait / wait_interval)): + if _check_port(port): + print("✅ Embedding server is ready!") + threading.Thread(target=self._log_monitor, daemon=True).start() + return True, port + + if self.server_process.poll() is not None: + print("❌ ERROR: Server terminated during startup.") + self._print_recent_output() + return False, port + + time.sleep(wait_interval) + + print(f"❌ ERROR: Server failed to start within {max_wait} seconds.") + self.stop_server() + return False, port def _print_recent_output(self): """Print any recent output from the server process.""" if not self.server_process or not self.server_process.stdout: return try: - # Read any available output - if select.select([self.server_process.stdout], [], [], 0)[0]: output = self.server_process.stdout.read() if output: diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index 0f40a85..7792af0 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -80,7 +80,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): embedding_mode = self.meta.get("embedding_mode", "sentence-transformers") - server_started = self.embedding_server_manager.start_server( + server_started, actual_port = self.embedding_server_manager.start_server( port=port, model_name=self.embedding_model, passages_file=passages_source_file, @@ -89,7 +89,11 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC): enable_warmup=kwargs.get("enable_warmup", False), ) if not server_started: - raise RuntimeError(f"Failed to start embedding server on port {port}") + raise RuntimeError(f"Failed to start embedding server on port {actual_port}") + + # Update the port information for future use + if hasattr(self, '_actual_server_port'): + self._actual_server_port = actual_port def compute_query_embedding( self, query: str, zmq_port: int = 5557, use_server_if_available: bool = True diff --git a/pyproject.toml b/pyproject.toml index 3a0c027..b42dcbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "llama-index-embeddings-huggingface>=0.5.5", "mlx>=0.26.3", "mlx-lm>=0.26.0", + "psutil>=5.8.0", ] [project.optional-dependencies] diff --git a/uv.lock b/uv.lock index aa661f6..6c84ad2 100644 --- a/uv.lock +++ b/uv.lock @@ -1834,10 +1834,14 @@ source = { editable = "packages/leann-core" } dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "tqdm" }, ] [package.metadata] -requires-dist = [{ name = "numpy", specifier = ">=1.20.0" }] +requires-dist = [ + { name = "numpy", specifier = ">=1.20.0" }, + { name = "tqdm", specifier = ">=4.60.0" }, +] [[package]] name = "leann-workspace" @@ -1851,7 +1855,6 @@ dependencies = [ { name = "flask" }, { name = "flask-compress" }, { name = "ipykernel" }, - { name = "leann-backend-diskann" }, { name = "leann-backend-hnsw" }, { name = "leann-core" }, { name = "llama-index" }, @@ -1867,6 +1870,7 @@ dependencies = [ { name = "ollama" }, { name = "openai" }, { name = "protobuf" }, + { name = "psutil" }, { name = "pypdf2" }, { name = "requests" }, { name = "sentence-transformers" }, @@ -1884,6 +1888,9 @@ dev = [ { name = "pytest-cov" }, { name = "ruff" }, ] +diskann = [ + { name = "leann-backend-diskann" }, +] [package.metadata] requires-dist = [ @@ -1896,7 +1903,7 @@ requires-dist = [ { name = "flask-compress" }, { name = "huggingface-hub", marker = "extra == 'dev'", specifier = ">=0.20.0" }, { name = "ipykernel", specifier = "==6.29.5" }, - { name = "leann-backend-diskann", editable = "packages/leann-backend-diskann" }, + { name = "leann-backend-diskann", marker = "extra == 'diskann'", editable = "packages/leann-backend-diskann" }, { name = "leann-backend-hnsw", editable = "packages/leann-backend-hnsw" }, { name = "leann-core", editable = "packages/leann-core" }, { name = "llama-index", specifier = ">=0.12.44" }, @@ -1912,6 +1919,7 @@ requires-dist = [ { name = "ollama" }, { name = "openai", specifier = ">=1.0.0" }, { name = "protobuf", specifier = "==4.25.3" }, + { name = "psutil", specifier = ">=5.8.0" }, { name = "pypdf2", specifier = ">=3.0.0" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0" }, @@ -1922,7 +1930,7 @@ requires-dist = [ { name = "torch" }, { name = "tqdm" }, ] -provides-extras = ["dev"] +provides-extras = ["dev", "diskann"] [[package]] name = "llama-cloud"