Compare commits
2 Commits
fix/drop-p
...
fix/passag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
895dd8cd5a | ||
|
|
01ded385df |
@@ -1,4 +1,5 @@
|
|||||||
import atexit
|
import atexit
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
|
|||||||
# Note: All cross-process scanning helpers removed for simplicity
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_resolve(path: Path) -> str:
|
||||||
|
"""Resolve paths safely even if the target does not yet exist."""
|
||||||
|
try:
|
||||||
|
return str(path.resolve(strict=False))
|
||||||
|
except Exception:
|
||||||
|
return str(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_stat_signature(path: Path) -> dict:
    """Describe the current on-disk state of ``path`` as a small dict.

    The result always contains ``"path"`` (the safely resolved path string).
    Depending on the outcome of ``path.stat()`` it additionally contains:

    - ``"missing": True`` when the file does not exist;
    - ``"error"`` with the exception text for any other failure;
    - ``"mtime_ns"`` and ``"size"`` when the stat call succeeds.
    """
    info: dict[str, object] = {"path": _safe_resolve(path)}

    try:
        stat_result = path.stat()
    except FileNotFoundError:
        info["missing"] = True
        return info
    except Exception as err:  # pragma: no cover - unexpected filesystem errors
        info["error"] = str(err)
        return info

    info["mtime_ns"] = stat_result.st_mtime_ns
    info["size"] = stat_result.st_size
    return info
|
||||||
|
|
||||||
|
|
||||||
|
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
|
||||||
|
"""Collect modification signatures for metadata and referenced passage files."""
|
||||||
|
if not passages_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
meta_path = Path(passages_file)
|
||||||
|
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with meta_path.open(encoding="utf-8") as fh:
|
||||||
|
meta = json.load(fh)
|
||||||
|
except FileNotFoundError:
|
||||||
|
signature["meta_missing"] = True
|
||||||
|
signature["sources"] = []
|
||||||
|
return signature
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
signature["meta_error"] = f"json_error:{exc}"
|
||||||
|
signature["sources"] = []
|
||||||
|
return signature
|
||||||
|
except Exception as exc: # pragma: no cover - unexpected errors
|
||||||
|
signature["meta_error"] = str(exc)
|
||||||
|
signature["sources"] = []
|
||||||
|
return signature
|
||||||
|
|
||||||
|
base_dir = meta_path.parent
|
||||||
|
seen_paths: set[str] = set()
|
||||||
|
source_signatures: list[dict[str, object]] = []
|
||||||
|
|
||||||
|
for source in meta.get("passage_sources", []):
|
||||||
|
for key, kind in (
|
||||||
|
("path", "passages"),
|
||||||
|
("path_relative", "passages"),
|
||||||
|
("index_path", "index"),
|
||||||
|
("index_path_relative", "index"),
|
||||||
|
):
|
||||||
|
raw_path = source.get(key)
|
||||||
|
if not raw_path:
|
||||||
|
continue
|
||||||
|
candidate = Path(raw_path)
|
||||||
|
if not candidate.is_absolute():
|
||||||
|
candidate = base_dir / candidate
|
||||||
|
resolved = _safe_resolve(candidate)
|
||||||
|
if resolved in seen_paths:
|
||||||
|
continue
|
||||||
|
seen_paths.add(resolved)
|
||||||
|
sig = _safe_stat_signature(candidate)
|
||||||
|
sig["kind"] = kind
|
||||||
|
source_signatures.append(sig)
|
||||||
|
|
||||||
|
signature["sources"] = source_signatures
|
||||||
|
return signature
|
||||||
|
|
||||||
|
|
||||||
|
# Note: All cross-process scanning helpers removed for simplicity
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingServerManager:
|
class EmbeddingServerManager:
|
||||||
"""
|
"""
|
||||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||||
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
|
|||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
provider_options = kwargs.pop("provider_options", None)
|
provider_options = kwargs.pop("provider_options", None)
|
||||||
|
passages_file = kwargs.get("passages_file", "")
|
||||||
|
|
||||||
config_signature = {
|
config_signature = self._build_config_signature(
|
||||||
"model_name": model_name,
|
model_name=model_name,
|
||||||
"passages_file": kwargs.get("passages_file", ""),
|
embedding_mode=embedding_mode,
|
||||||
"embedding_mode": embedding_mode,
|
provider_options=provider_options,
|
||||||
"provider_options": provider_options or {},
|
passages_file=passages_file,
|
||||||
}
|
)
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
# If this manager already has a live server, just reuse it
|
||||||
if (
|
if (
|
||||||
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
|
|||||||
port,
|
port,
|
||||||
model_name,
|
model_name,
|
||||||
embedding_mode,
|
embedding_mode,
|
||||||
|
config_signature=config_signature,
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_config_signature(
    self,
    *,
    model_name: str,
    embedding_mode: str,
    provider_options: Optional[dict],
    passages_file: Optional[str],
) -> dict:
    """Create a signature describing the current server configuration.

    Args:
        model_name: Name of the embedding model.
        embedding_mode: Embedding mode identifier.
        provider_options: Optional provider settings; a falsy value is
            recorded as an empty dict.
        passages_file: Optional path to the passages metadata file; a falsy
            value is recorded as an empty string.

    Returns:
        A dict with the keys ``model_name``, ``passages_file``,
        ``embedding_mode``, ``provider_options`` and ``passages_signature``.
    """
    # Snapshot the options instead of storing the caller's dict: if the
    # caller later mutates that dict in place, a signature holding the same
    # object would change with it and hide the configuration change.
    options_snapshot = dict(provider_options) if provider_options else {}

    return {
        "model_name": model_name,
        "passages_file": passages_file or "",
        "embedding_mode": embedding_mode,
        "provider_options": options_snapshot,
        "passages_signature": _build_passages_signature(passages_file),
    }
|
||||||
|
|
||||||
def _start_server_colab(
|
def _start_server_colab(
|
||||||
self,
|
self,
|
||||||
port: int,
|
port: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
*,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
|
|||||||
command,
|
command,
|
||||||
actual_port,
|
actual_port,
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
|
config_signature=config_signature,
|
||||||
)
|
)
|
||||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||||
if started:
|
if started:
|
||||||
self._server_config = {
|
self._server_config = config_signature or {
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"passages_file": kwargs.get("passages_file", ""),
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
"embedding_mode": embedding_mode,
|
"embedding_mode": embedding_mode,
|
||||||
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
|
|||||||
command,
|
command,
|
||||||
port,
|
port,
|
||||||
provider_options=provider_options,
|
provider_options=provider_options,
|
||||||
|
config_signature=config_signature,
|
||||||
)
|
)
|
||||||
started, ready_port = self._wait_for_server_ready(port)
|
started, ready_port = self._wait_for_server_ready(port)
|
||||||
if started:
|
if started:
|
||||||
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
|
|||||||
self,
|
self,
|
||||||
command: list,
|
command: list,
|
||||||
port: int,
|
port: int,
|
||||||
|
*,
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
|
|||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse (best effort; refined later when ready)
|
# Record config for in-process reuse (best effort; refined later when ready)
|
||||||
try:
|
if config_signature is not None:
|
||||||
self._server_config = {
|
self._server_config = config_signature
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
else: # Fallback for unexpected code paths
|
||||||
if "--model-name" in command
|
try:
|
||||||
else "",
|
self._server_config = {
|
||||||
"passages_file": command[command.index("--passages-file") + 1]
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
if "--passages-file" in command
|
if "--model-name" in command
|
||||||
else "",
|
else "",
|
||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
"passages_file": command[command.index("--passages-file") + 1]
|
||||||
if "--embedding-mode" in command
|
if "--passages-file" in command
|
||||||
else "sentence-transformers",
|
else "",
|
||||||
"provider_options": provider_options or {},
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
}
|
if "--embedding-mode" in command
|
||||||
except Exception:
|
else "sentence-transformers",
|
||||||
self._server_config = {
|
"provider_options": provider_options or {},
|
||||||
"model_name": "",
|
}
|
||||||
"passages_file": "",
|
except Exception:
|
||||||
"embedding_mode": "sentence-transformers",
|
self._server_config = {
|
||||||
"provider_options": provider_options or {},
|
"model_name": "",
|
||||||
}
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
# Register atexit callback only when we actually start a process
|
# Register atexit callback only when we actually start a process
|
||||||
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
|
|||||||
self,
|
self,
|
||||||
command: list,
|
command: list,
|
||||||
port: int,
|
port: int,
|
||||||
|
*,
|
||||||
provider_options: Optional[dict] = None,
|
provider_options: Optional[dict] = None,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
|
|||||||
atexit.register(self._finalize_process)
|
atexit.register(self._finalize_process)
|
||||||
self._atexit_registered = True
|
self._atexit_registered = True
|
||||||
# Record config for in-process reuse is best-effort in Colab mode
|
# Record config for in-process reuse is best-effort in Colab mode
|
||||||
self._server_config = {
|
if config_signature is not None:
|
||||||
"model_name": "",
|
self._server_config = config_signature
|
||||||
"passages_file": "",
|
else:
|
||||||
"embedding_mode": "sentence-transformers",
|
self._server_config = {
|
||||||
"provider_options": provider_options or {},
|
"model_name": "",
|
||||||
}
|
"passages_file": "",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||||
|
|||||||
137
tests/test_embedding_server_manager.py
Normal file
137
tests/test_embedding_server_manager.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
|
||||||
|
|
||||||
|
class DummyProcess:
    """In-memory stand-in for a process handle.

    Exposes ``pid``, ``poll``, ``terminate``, ``kill`` and ``wait``. The
    process counts as running until one of the last three is called.
    """

    def __init__(self):
        self.pid = 12345
        self._terminated = False

    def poll(self):
        """Return ``None`` while running, ``0`` once terminated."""
        if self._terminated:
            return 0
        return None

    def terminate(self):
        """Mark the process as terminated."""
        self._terminated = True

    def kill(self):
        """Mark the process as terminated."""
        self._terminated = True

    def wait(self, timeout=None):
        """Mark the process as terminated and return exit code ``0``."""
        self._terminated = True
        return 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def embedding_manager(monkeypatch):
    """Provide a manager whose process handling is replaced by in-memory stubs.

    Returns a ``(manager, start_calls)`` tuple. ``start_calls`` receives the
    ``config_signature`` passed to every stubbed ``_start_new_server`` call.
    """
    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
    start_calls = []

    def _stub_start_new_server(self, port, model_name, embedding_mode, **kwargs):
        # Record the signature and pretend a server came up on the requested port.
        signature = kwargs.get("config_signature")
        start_calls.append(signature)
        self.server_process = DummyProcess()
        self.server_port = port
        self._server_config = signature
        return True, port

    def _stub_stop_server(self):
        # Ensure stop_server doesn't try to operate on real subprocesses
        self.server_process = None
        self.server_port = None
        self._server_config = None

    # Port discovery simply hands back the requested port.
    monkeypatch.setattr(
        "leann.embedding_server_manager._get_available_port",
        lambda start_port: start_port,
    )
    monkeypatch.setattr(
        EmbeddingServerManager, "_start_new_server", _stub_start_new_server
    )
    monkeypatch.setattr(EmbeddingServerManager, "stop_server", _stub_stop_server)

    return manager, start_calls
|
||||||
|
|
||||||
|
|
||||||
|
def _write_meta(meta_path, passages_name, index_name, total):
|
||||||
|
meta_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": "test-model",
|
||||||
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"dimensions": 3,
|
||||||
|
"backend_kwargs": {},
|
||||||
|
"passage_sources": [
|
||||||
|
{
|
||||||
|
"type": "jsonl",
|
||||||
|
"path": passages_name,
|
||||||
|
"index_path": index_name,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total_passages": total,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
    """An unchanged passages file reuses the server; a changed one restarts it."""
    manager, start_calls = embedding_manager

    meta_path = tmp_path / "example.meta.json"
    passages_path = tmp_path / "example.passages.jsonl"
    index_path = tmp_path / "example.passages.idx"

    def request_server():
        # Every phase of the test asks for the same server configuration.
        return manager.start_server(
            port=6000,
            model_name="test-model",
            passages_file=str(meta_path),
        )

    passages_path.write_text("first\n", encoding="utf-8")
    index_path.write_bytes(b"index")
    _write_meta(meta_path, passages_path.name, index_path.name, total=1)

    # Initial start populates signature
    ok, port = request_server()
    assert ok
    assert port == 6000
    assert len(start_calls) == 1

    initial_signature = start_calls[0]["passages_signature"]

    # No metadata change => reuse existing server
    ok, port_again = request_server()
    assert ok
    assert port_again == 6000
    assert len(start_calls) == 1

    # Modify passage data and metadata to force signature change
    time.sleep(0.01)  # Ensure filesystem timestamps move forward
    passages_path.write_text("second\n", encoding="utf-8")
    _write_meta(meta_path, passages_path.name, index_path.name, total=2)

    ok, port_third = request_server()
    assert ok
    assert port_third == 6000
    assert len(start_calls) == 2

    updated_signature = start_calls[1]["passages_signature"]
    assert updated_signature != initial_signature
|
||||||
Reference in New Issue
Block a user