diff --git a/packages/leann-core/src/leann/embedding_server_manager.py b/packages/leann-core/src/leann/embedding_server_manager.py
index 7a2fcb2..ca61d05 100644
--- a/packages/leann-core/src/leann/embedding_server_manager.py
+++ b/packages/leann-core/src/leann/embedding_server_manager.py
@@ -1,4 +1,5 @@
 import atexit
+import json
 import logging
 import os
 import socket
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
 # Note: All cross-process scanning helpers removed for simplicity
 
 
+def _safe_resolve(path: Path) -> str:
+    """Resolve paths safely even if the target does not yet exist."""
+    try:
+        return str(path.resolve(strict=False))
+    except Exception:
+        return str(path)
+
+
+def _safe_stat_signature(path: Path) -> dict:
+    """Return a lightweight signature describing the current state of a path."""
+    signature: dict[str, object] = {"path": _safe_resolve(path)}
+    try:
+        stat = path.stat()
+    except FileNotFoundError:
+        signature["missing"] = True
+    except Exception as exc:  # pragma: no cover - unexpected filesystem errors
+        signature["error"] = str(exc)
+    else:
+        signature["mtime_ns"] = stat.st_mtime_ns
+        signature["size"] = stat.st_size
+    return signature
+
+
+def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
+    """Collect modification signatures for metadata and referenced passage files."""
+    if not passages_file:
+        return None
+
+    meta_path = Path(passages_file)
+    signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
+
+    try:
+        with meta_path.open(encoding="utf-8") as fh:
+            meta = json.load(fh)
+    except FileNotFoundError:
+        signature["meta_missing"] = True
+        signature["sources"] = []
+        return signature
+    except json.JSONDecodeError as exc:
+        signature["meta_error"] = f"json_error:{exc}"
+        signature["sources"] = []
+        return signature
+    except Exception as exc:  # pragma: no cover - unexpected errors
+        signature["meta_error"] = str(exc)
+        signature["sources"] = []
+        return signature
+
+    base_dir = meta_path.parent
+    seen_paths: set[str] = set()
+    source_signatures: list[dict[str, object]] = []
+
+    for source in meta.get("passage_sources", []):
+        for key, kind in (
+            ("path", "passages"),
+            ("path_relative", "passages"),
+            ("index_path", "index"),
+            ("index_path_relative", "index"),
+        ):
+            raw_path = source.get(key)
+            if not raw_path:
+                continue
+            candidate = Path(raw_path)
+            if not candidate.is_absolute():
+                candidate = base_dir / candidate
+            resolved = _safe_resolve(candidate)
+            if resolved in seen_paths:
+                continue
+            seen_paths.add(resolved)
+            sig = _safe_stat_signature(candidate)
+            sig["kind"] = kind
+            source_signatures.append(sig)
+
+    signature["sources"] = source_signatures
+    return signature
+
+
+# Note: All cross-process scanning helpers removed for simplicity
+
+
 class EmbeddingServerManager:
     """
     A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
         """Start the embedding server."""
         # passages_file may be present in kwargs for server CLI, but we don't need it here
         provider_options = kwargs.pop("provider_options", None)
+        passages_file = kwargs.get("passages_file", "")
 
-        config_signature = {
-            "model_name": model_name,
-            "passages_file": kwargs.get("passages_file", ""),
-            "embedding_mode": embedding_mode,
-            "provider_options": provider_options or {},
-        }
+        config_signature = self._build_config_signature(
+            model_name=model_name,
+            embedding_mode=embedding_mode,
+            provider_options=provider_options,
+            passages_file=passages_file,
+        )
 
         # If this manager already has a live server, just reuse it
         if (
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
                 port,
                 model_name,
                 embedding_mode,
+                config_signature=config_signature,
                 provider_options=provider_options,
                 **kwargs,
             )
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
             **kwargs,
         )
 
+    def _build_config_signature(
+        self,
+        *,
+        model_name: str,
+        embedding_mode: str,
+        provider_options: Optional[dict],
+        passages_file: Optional[str],
+    ) -> dict:
+        """Create a signature describing the current server configuration."""
+        return {
+            "model_name": model_name,
+            "passages_file": passages_file or "",
+            "embedding_mode": embedding_mode,
+            "provider_options": provider_options or {},
+            "passages_signature": _build_passages_signature(passages_file),
+        }
+
     def _start_server_colab(
         self,
         port: int,
         model_name: str,
         embedding_mode: str = "sentence-transformers",
+        *,
+        config_signature: Optional[dict] = None,
         provider_options: Optional[dict] = None,
         **kwargs,
     ) -> tuple[bool, int]:
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
                 command,
                 actual_port,
                 provider_options=provider_options,
+                config_signature=config_signature,
             )
             started, ready_port = self._wait_for_server_ready_colab(actual_port)
             if started:
-                self._server_config = {
+                self._server_config = config_signature or {
                     "model_name": model_name,
                     "passages_file": kwargs.get("passages_file", ""),
                     "embedding_mode": embedding_mode,
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
                 command,
                 port,
                 provider_options=provider_options,
+                config_signature=config_signature,
             )
             started, ready_port = self._wait_for_server_ready(port)
             if started:
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
         self,
         command: list,
         port: int,
+        *,
         provider_options: Optional[dict] = None,
+        config_signature: Optional[dict] = None,
     ) -> None:
         """Launch the server process."""
         project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
         )
         self.server_port = port
         # Record config for in-process reuse (best effort; refined later when ready)
-        try:
-            self._server_config = {
-                "model_name": command[command.index("--model-name") + 1]
-                if "--model-name" in command
-                else "",
-                "passages_file": command[command.index("--passages-file") + 1]
-                if "--passages-file" in command
-                else "",
-                "embedding_mode": command[command.index("--embedding-mode") + 1]
-                if "--embedding-mode" in command
-                else "sentence-transformers",
-                "provider_options": provider_options or {},
-            }
-        except Exception:
-            self._server_config = {
-                "model_name": "",
-                "passages_file": "",
-                "embedding_mode": "sentence-transformers",
-                "provider_options": provider_options or {},
-            }
+        if config_signature is not None:
+            self._server_config = config_signature
+        else:  # Fallback for unexpected code paths
+            try:
+                self._server_config = {
+                    "model_name": command[command.index("--model-name") + 1]
+                    if "--model-name" in command
+                    else "",
+                    "passages_file": command[command.index("--passages-file") + 1]
+                    if "--passages-file" in command
+                    else "",
+                    "embedding_mode": command[command.index("--embedding-mode") + 1]
+                    if "--embedding-mode" in command
+                    else "sentence-transformers",
+                    "provider_options": provider_options or {},
+                }
+            except Exception:
+                self._server_config = {
+                    "model_name": "",
+                    "passages_file": "",
+                    "embedding_mode": "sentence-transformers",
+                    "provider_options": provider_options or {},
+                }
         logger.info(f"Server process started with PID: {self.server_process.pid}")
 
         # Register atexit callback only when we actually start a process
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
         self,
         command: list,
         port: int,
+        *,
         provider_options: Optional[dict] = None,
+        config_signature: Optional[dict] = None,
     ) -> None:
         """Launch the server process with Colab-specific settings."""
         logger.info(f"Colab Command: {' '.join(command)}")
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
             atexit.register(self._finalize_process)
             self._atexit_registered = True
         # Record config for in-process reuse is best-effort in Colab mode
-        self._server_config = {
-            "model_name": "",
-            "passages_file": "",
-            "embedding_mode": "sentence-transformers",
-            "provider_options": provider_options or {},
-        }
+        if config_signature is not None:
+            self._server_config = config_signature
+        else:
+            self._server_config = {
+                "model_name": "",
+                "passages_file": "",
+                "embedding_mode": "sentence-transformers",
+                "provider_options": provider_options or {},
+            }
 
     def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
         """Wait for the server to be ready with Colab-specific timeout."""
diff --git a/tests/test_embedding_server_manager.py b/tests/test_embedding_server_manager.py
new file mode 100644
index 0000000..71e91f3
--- /dev/null
+++ b/tests/test_embedding_server_manager.py
@@ -0,0 +1,137 @@
+import json
+import time
+
+import pytest
+from leann.embedding_server_manager import EmbeddingServerManager
+
+
+class DummyProcess:
+    def __init__(self):
+        self.pid = 12345
+        self._terminated = False
+
+    def poll(self):
+        return 0 if self._terminated else None
+
+    def terminate(self):
+        self._terminated = True
+
+    def kill(self):
+        self._terminated = True
+
+    def wait(self, timeout=None):
+        self._terminated = True
+        return 0
+
+
+@pytest.fixture
+def embedding_manager(monkeypatch):
+    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
+
+    def fake_get_available_port(start_port):
+        return start_port
+
+    monkeypatch.setattr(
+        "leann.embedding_server_manager._get_available_port",
+        fake_get_available_port,
+    )
+
+    start_calls = []
+
+    def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
+        config_signature = kwargs.get("config_signature")
+        start_calls.append(config_signature)
+        self.server_process = DummyProcess()
+        self.server_port = port
+        self._server_config = config_signature
+        return True, port
+
+    monkeypatch.setattr(
+        EmbeddingServerManager,
+        "_start_new_server",
+        fake_start_new_server,
+    )
+
+    # Ensure stop_server doesn't try to operate on real subprocesses
+    def fake_stop_server(self):
+        self.server_process = None
+        self.server_port = None
+        self._server_config = None
+
+    monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
+
+    return manager, start_calls
+
+
+def _write_meta(meta_path, passages_name, index_name, total):
+    meta_path.write_text(
+        json.dumps(
+            {
+                "backend_name": "hnsw",
+                "embedding_model": "test-model",
+                "embedding_mode": "sentence-transformers",
+                "dimensions": 3,
+                "backend_kwargs": {},
+                "passage_sources": [
+                    {
+                        "type": "jsonl",
+                        "path": passages_name,
+                        "index_path": index_name,
+                    }
+                ],
+                "total_passages": total,
+            }
+        ),
+        encoding="utf-8",
+    )
+
+
+def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
+    manager, start_calls = embedding_manager
+
+    meta_path = tmp_path / "example.meta.json"
+    passages_path = tmp_path / "example.passages.jsonl"
+    index_path = tmp_path / "example.passages.idx"
+
+    passages_path.write_text("first\n", encoding="utf-8")
+    index_path.write_bytes(b"index")
+    _write_meta(meta_path, passages_path.name, index_path.name, total=1)
+
+    # Initial start populates signature
+    ok, port = manager.start_server(
+        port=6000,
+        model_name="test-model",
+        passages_file=str(meta_path),
+    )
+    assert ok
+    assert port == 6000
+    assert len(start_calls) == 1
+
+    initial_signature = start_calls[0]["passages_signature"]
+
+    # No metadata change => reuse existing server
+    ok, port_again = manager.start_server(
+        port=6000,
+        model_name="test-model",
+        passages_file=str(meta_path),
+    )
+    assert ok
+    assert port_again == 6000
+    assert len(start_calls) == 1
+
+    # Modify passage data and metadata to force signature change
+    time.sleep(0.01)  # Ensure filesystem timestamps move forward
+    passages_path.write_text("second\n", encoding="utf-8")
+    _write_meta(meta_path, passages_path.name, index_path.name, total=2)
+
+    ok, port_third = manager.start_server(
+        port=6000,
+        model_name="test-model",
+        passages_file=str(meta_path),
+    )
+    assert ok
+    assert port_third == 6000
+    assert len(start_calls) == 2
+
+    updated_signature = start_calls[1]["passages_signature"]
+    assert updated_signature != initial_signature