Compare commits
2 Commits
fix/ask-cl
...
fix/passag
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
895dd8cd5a | ||
|
|
01ded385df |
@@ -546,9 +546,6 @@ leann search my-docs "machine learning concepts"
|
||||
# Interactive chat with your documents
|
||||
leann ask my-docs --interactive
|
||||
|
||||
# Ask a single question (non-interactive)
|
||||
leann ask my-docs "Where are prompts configured?"
|
||||
|
||||
# List all your indexes
|
||||
leann list
|
||||
|
||||
|
||||
@@ -257,11 +257,6 @@ Examples:
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||
ask_parser.add_argument("index_name", help="Index name")
|
||||
ask_parser.add_argument(
|
||||
"query",
|
||||
nargs="?",
|
||||
help="Question to ask (omit for prompt or when using --interactive)",
|
||||
)
|
||||
ask_parser.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
@@ -1536,29 +1531,7 @@ Examples:
|
||||
|
||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||
|
||||
llm_kwargs: dict[str, Any] = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
def _ask_once(prompt: str) -> None:
|
||||
response = chat.ask(
|
||||
prompt,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
initial_query = (args.query or "").strip()
|
||||
|
||||
if args.interactive:
|
||||
if initial_query:
|
||||
_ask_once(initial_query)
|
||||
|
||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||
print("=" * 40)
|
||||
|
||||
@@ -1571,14 +1544,41 @@ Examples:
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
_ask_once(user_input)
|
||||
else:
|
||||
query = initial_query or input("Enter your question: ").strip()
|
||||
if not query:
|
||||
print("No question provided. Exiting.")
|
||||
return
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
_ask_once(query)
|
||||
response = chat.ask(
|
||||
user_input,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
else:
|
||||
query = input("Enter your question: ").strip()
|
||||
if query:
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=args.prune_ratio,
|
||||
recompute_embeddings=args.recompute_embeddings,
|
||||
pruning_strategy=args.pruning_strategy,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"LEANN: {response}")
|
||||
|
||||
async def run(self, args=None):
|
||||
parser = self.create_parser()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
def _safe_resolve(path: Path) -> str:
|
||||
"""Resolve paths safely even if the target does not yet exist."""
|
||||
try:
|
||||
return str(path.resolve(strict=False))
|
||||
except Exception:
|
||||
return str(path)
|
||||
|
||||
|
||||
def _safe_stat_signature(path: Path) -> dict:
|
||||
"""Return a lightweight signature describing the current state of a path."""
|
||||
signature: dict[str, object] = {"path": _safe_resolve(path)}
|
||||
try:
|
||||
stat = path.stat()
|
||||
except FileNotFoundError:
|
||||
signature["missing"] = True
|
||||
except Exception as exc: # pragma: no cover - unexpected filesystem errors
|
||||
signature["error"] = str(exc)
|
||||
else:
|
||||
signature["mtime_ns"] = stat.st_mtime_ns
|
||||
signature["size"] = stat.st_size
|
||||
return signature
|
||||
|
||||
|
||||
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
|
||||
"""Collect modification signatures for metadata and referenced passage files."""
|
||||
if not passages_file:
|
||||
return None
|
||||
|
||||
meta_path = Path(passages_file)
|
||||
signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}
|
||||
|
||||
try:
|
||||
with meta_path.open(encoding="utf-8") as fh:
|
||||
meta = json.load(fh)
|
||||
except FileNotFoundError:
|
||||
signature["meta_missing"] = True
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except json.JSONDecodeError as exc:
|
||||
signature["meta_error"] = f"json_error:{exc}"
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
except Exception as exc: # pragma: no cover - unexpected errors
|
||||
signature["meta_error"] = str(exc)
|
||||
signature["sources"] = []
|
||||
return signature
|
||||
|
||||
base_dir = meta_path.parent
|
||||
seen_paths: set[str] = set()
|
||||
source_signatures: list[dict[str, object]] = []
|
||||
|
||||
for source in meta.get("passage_sources", []):
|
||||
for key, kind in (
|
||||
("path", "passages"),
|
||||
("path_relative", "passages"),
|
||||
("index_path", "index"),
|
||||
("index_path_relative", "index"),
|
||||
):
|
||||
raw_path = source.get(key)
|
||||
if not raw_path:
|
||||
continue
|
||||
candidate = Path(raw_path)
|
||||
if not candidate.is_absolute():
|
||||
candidate = base_dir / candidate
|
||||
resolved = _safe_resolve(candidate)
|
||||
if resolved in seen_paths:
|
||||
continue
|
||||
seen_paths.add(resolved)
|
||||
sig = _safe_stat_signature(candidate)
|
||||
sig["kind"] = kind
|
||||
source_signatures.append(sig)
|
||||
|
||||
signature["sources"] = source_signatures
|
||||
return signature
|
||||
|
||||
|
||||
# Note: All cross-process scanning helpers removed for simplicity
|
||||
|
||||
|
||||
class EmbeddingServerManager:
|
||||
"""
|
||||
A simplified manager for embedding server processes that avoids complex update mechanisms.
|
||||
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
|
||||
"""Start the embedding server."""
|
||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||
provider_options = kwargs.pop("provider_options", None)
|
||||
passages_file = kwargs.get("passages_file", "")
|
||||
|
||||
config_signature = {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
config_signature = self._build_config_signature(
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
provider_options=provider_options,
|
||||
passages_file=passages_file,
|
||||
)
|
||||
|
||||
# If this manager already has a live server, just reuse it
|
||||
if (
|
||||
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
|
||||
port,
|
||||
model_name,
|
||||
embedding_mode,
|
||||
config_signature=config_signature,
|
||||
provider_options=provider_options,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _build_config_signature(
|
||||
self,
|
||||
*,
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
provider_options: Optional[dict],
|
||||
passages_file: Optional[str],
|
||||
) -> dict:
|
||||
"""Create a signature describing the current server configuration."""
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"passages_file": passages_file or "",
|
||||
"embedding_mode": embedding_mode,
|
||||
"provider_options": provider_options or {},
|
||||
"passages_signature": _build_passages_signature(passages_file),
|
||||
}
|
||||
|
||||
def _start_server_colab(
|
||||
self,
|
||||
port: int,
|
||||
model_name: str,
|
||||
embedding_mode: str = "sentence-transformers",
|
||||
*,
|
||||
config_signature: Optional[dict] = None,
|
||||
provider_options: Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> tuple[bool, int]:
|
||||
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
|
||||
command,
|
||||
actual_port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||
if started:
|
||||
self._server_config = {
|
||||
self._server_config = config_signature or {
|
||||
"model_name": model_name,
|
||||
"passages_file": kwargs.get("passages_file", ""),
|
||||
"embedding_mode": embedding_mode,
|
||||
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
|
||||
command,
|
||||
port,
|
||||
provider_options=provider_options,
|
||||
config_signature=config_signature,
|
||||
)
|
||||
started, ready_port = self._wait_for_server_ready(port)
|
||||
if started:
|
||||
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process."""
|
||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
|
||||
)
|
||||
self.server_port = port
|
||||
# Record config for in-process reuse (best effort; refined later when ready)
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else: # Fallback for unexpected code paths
|
||||
try:
|
||||
self._server_config = {
|
||||
"model_name": command[command.index("--model-name") + 1]
|
||||
if "--model-name" in command
|
||||
else "",
|
||||
"passages_file": command[command.index("--passages-file") + 1]
|
||||
if "--passages-file" in command
|
||||
else "",
|
||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||
if "--embedding-mode" in command
|
||||
else "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
except Exception:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||
|
||||
# Register atexit callback only when we actually start a process
|
||||
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
|
||||
self,
|
||||
command: list,
|
||||
port: int,
|
||||
*,
|
||||
provider_options: Optional[dict] = None,
|
||||
config_signature: Optional[dict] = None,
|
||||
) -> None:
|
||||
"""Launch the server process with Colab-specific settings."""
|
||||
logger.info(f"Colab Command: {' '.join(command)}")
|
||||
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
|
||||
atexit.register(self._finalize_process)
|
||||
self._atexit_registered = True
|
||||
# Record config for in-process reuse is best-effort in Colab mode
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
if config_signature is not None:
|
||||
self._server_config = config_signature
|
||||
else:
|
||||
self._server_config = {
|
||||
"model_name": "",
|
||||
"passages_file": "",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"provider_options": provider_options or {},
|
||||
}
|
||||
|
||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||
"""Wait for the server to be ready with Colab-specific timeout."""
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from leann.cli import LeannCLI
|
||||
|
||||
|
||||
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
cli = LeannCLI()
|
||||
parser = cli.create_parser()
|
||||
|
||||
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
|
||||
|
||||
assert args.command == "ask"
|
||||
assert args.index_name == "my-docs"
|
||||
assert args.query == "Where are prompts configured?"
|
||||
137
tests/test_embedding_server_manager.py
Normal file
137
tests/test_embedding_server_manager.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
|
||||
|
||||
class DummyProcess:
|
||||
def __init__(self):
|
||||
self.pid = 12345
|
||||
self._terminated = False
|
||||
|
||||
def poll(self):
|
||||
return 0 if self._terminated else None
|
||||
|
||||
def terminate(self):
|
||||
self._terminated = True
|
||||
|
||||
def kill(self):
|
||||
self._terminated = True
|
||||
|
||||
def wait(self, timeout=None):
|
||||
self._terminated = True
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_manager(monkeypatch):
|
||||
manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
|
||||
|
||||
def fake_get_available_port(start_port):
|
||||
return start_port
|
||||
|
||||
monkeypatch.setattr(
|
||||
"leann.embedding_server_manager._get_available_port",
|
||||
fake_get_available_port,
|
||||
)
|
||||
|
||||
start_calls = []
|
||||
|
||||
def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
|
||||
config_signature = kwargs.get("config_signature")
|
||||
start_calls.append(config_signature)
|
||||
self.server_process = DummyProcess()
|
||||
self.server_port = port
|
||||
self._server_config = config_signature
|
||||
return True, port
|
||||
|
||||
monkeypatch.setattr(
|
||||
EmbeddingServerManager,
|
||||
"_start_new_server",
|
||||
fake_start_new_server,
|
||||
)
|
||||
|
||||
# Ensure stop_server doesn't try to operate on real subprocesses
|
||||
def fake_stop_server(self):
|
||||
self.server_process = None
|
||||
self.server_port = None
|
||||
self._server_config = None
|
||||
|
||||
monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)
|
||||
|
||||
return manager, start_calls
|
||||
|
||||
|
||||
def _write_meta(meta_path, passages_name, index_name, total):
|
||||
meta_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"backend_name": "hnsw",
|
||||
"embedding_model": "test-model",
|
||||
"embedding_mode": "sentence-transformers",
|
||||
"dimensions": 3,
|
||||
"backend_kwargs": {},
|
||||
"passage_sources": [
|
||||
{
|
||||
"type": "jsonl",
|
||||
"path": passages_name,
|
||||
"index_path": index_name,
|
||||
}
|
||||
],
|
||||
"total_passages": total,
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
|
||||
manager, start_calls = embedding_manager
|
||||
|
||||
meta_path = tmp_path / "example.meta.json"
|
||||
passages_path = tmp_path / "example.passages.jsonl"
|
||||
index_path = tmp_path / "example.passages.idx"
|
||||
|
||||
passages_path.write_text("first\n", encoding="utf-8")
|
||||
index_path.write_bytes(b"index")
|
||||
_write_meta(meta_path, passages_path.name, index_path.name, total=1)
|
||||
|
||||
# Initial start populates signature
|
||||
ok, port = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port == 6000
|
||||
assert len(start_calls) == 1
|
||||
|
||||
initial_signature = start_calls[0]["passages_signature"]
|
||||
|
||||
# No metadata change => reuse existing server
|
||||
ok, port_again = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port_again == 6000
|
||||
assert len(start_calls) == 1
|
||||
|
||||
# Modify passage data and metadata to force signature change
|
||||
time.sleep(0.01) # Ensure filesystem timestamps move forward
|
||||
passages_path.write_text("second\n", encoding="utf-8")
|
||||
_write_meta(meta_path, passages_path.name, index_path.name, total=2)
|
||||
|
||||
ok, port_third = manager.start_server(
|
||||
port=6000,
|
||||
model_name="test-model",
|
||||
passages_file=str(meta_path),
|
||||
)
|
||||
assert ok
|
||||
assert port_third == 6000
|
||||
assert len(start_calls) == 2
|
||||
|
||||
updated_signature = start_calls[1]["passages_signature"]
|
||||
assert updated_signature != initial_signature
|
||||
Reference in New Issue
Block a user