Compare commits

..

1 Commits

Author SHA1 Message Date
Andy Lee
47aeb85f82 Allow 'leann ask' to accept a positional question 2025-09-23 15:18:51 -07:00
5 changed files with 83 additions and 316 deletions

View File

@@ -546,6 +546,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents # Interactive chat with your documents
leann ask my-docs --interactive leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes # List all your indexes
leann list leann list

View File

@@ -257,6 +257,11 @@ Examples:
# Ask command # Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions") ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name") ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--llm", "--llm",
type=str, type=str,
@@ -1531,7 +1536,29 @@ Examples:
chat = LeannChat(index_path=index_path, llm_config=llm_config) chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
    """Send a single prompt to the chat session and print its reply.

    Search settings are read from the parsed CLI ``args`` and the shared
    ``llm_kwargs`` of the enclosing scope, so one-shot and interactive
    questions are issued with the same options.
    """
    ask_options = {
        "top_k": args.top_k,
        "complexity": args.complexity,
        "beam_width": args.beam_width,
        "prune_ratio": args.prune_ratio,
        "recompute_embeddings": args.recompute_embeddings,
        "pruning_strategy": args.pruning_strategy,
        "llm_kwargs": llm_kwargs,
    }
    answer = chat.ask(prompt, **ask_options)
    print(f"LEANN: {answer}")
initial_query = (args.query or "").strip()
if args.interactive: if args.interactive:
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit") print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40) print("=" * 40)
@@ -1544,41 +1571,14 @@ Examples:
if not user_input: if not user_input:
continue continue
# Prepare LLM kwargs with thinking budget if specified _ask_once(user_input)
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
else: else:
query = input("Enter your question: ").strip() query = initial_query or input("Enter your question: ").strip()
if query: if not query:
# Prepare LLM kwargs with thinking budget if specified print("No question provided. Exiting.")
llm_kwargs = {} return
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask( _ask_once(query)
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
async def run(self, args=None): async def run(self, args=None):
parser = self.create_parser() parser = self.create_parser()

View File

@@ -1,5 +1,4 @@
import atexit import atexit
import json
import logging import logging
import os import os
import socket import socket
@@ -49,85 +48,6 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity # Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
"""Resolve paths safely even if the target does not yet exist."""
try:
return str(path.resolve(strict=False))
except Exception:
return str(path)
def _safe_stat_signature(path: Path) -> dict:
    """Describe the current on-disk state of *path* as a small dict.

    The result always has ``"path"``. It additionally carries ``"mtime_ns"``
    and ``"size"`` when the path can be stat'ed, ``"missing"`` when it does
    not exist, or ``"error"`` (the exception text) for any other failure.
    """
    info: dict[str, object] = {"path": _safe_resolve(path)}
    try:
        st = path.stat()
    except FileNotFoundError:
        info["missing"] = True
        return info
    except Exception as err:  # pragma: no cover - unexpected filesystem errors
        info["error"] = str(err)
        return info
    info.update(mtime_ns=st.st_mtime_ns, size=st.st_size)
    return info
def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
    """Collect modification signatures for metadata and referenced passage files.

    Args:
        passages_file: Path of the JSON metadata file. A falsy value means
            there is nothing to sign.

    Returns:
        ``None`` when ``passages_file`` is falsy. Otherwise a dict holding
        ``"meta"`` (stat signature of the metadata file) and ``"sources"``
        (one stat signature per distinct referenced file, each tagged with a
        ``"kind"`` of ``"passages"`` or ``"index"``). When the metadata file
        cannot be loaded, ``"meta_missing"`` or ``"meta_error"`` is set and
        ``"sources"`` is empty.
    """
    if not passages_file:
        return None

    meta_path = Path(passages_file)
    signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}

    try:
        with meta_path.open(encoding="utf-8") as fh:
            meta = json.load(fh)
    except FileNotFoundError:
        signature["meta_missing"] = True
        signature["sources"] = []
        return signature
    except json.JSONDecodeError as exc:
        signature["meta_error"] = f"json_error:{exc}"
        signature["sources"] = []
        return signature
    except Exception as exc:  # pragma: no cover - unexpected errors
        signature["meta_error"] = str(exc)
        signature["sources"] = []
        return signature

    # Well-formed JSON is not necessarily an object, and "passage_sources" may
    # be null or some non-list value. Treat all of those as "no sources"
    # rather than crashing: this signature is best-effort by design.
    raw_sources = meta.get("passage_sources") if isinstance(meta, dict) else None
    if not isinstance(raw_sources, list):
        raw_sources = []

    base_dir = meta_path.parent
    seen_paths: set[str] = set()
    source_signatures: list[dict[str, object]] = []
    path_keys = (
        ("path", "passages"),
        ("path_relative", "passages"),
        ("index_path", "index"),
        ("index_path_relative", "index"),
    )

    for source in raw_sources:
        if not isinstance(source, dict):
            continue  # malformed entry; nothing to stat
        for key, kind in path_keys:
            raw_path = source.get(key)
            if not raw_path or not isinstance(raw_path, str):
                continue
            candidate = Path(raw_path)
            if not candidate.is_absolute():
                # Relative entries are interpreted relative to the metadata file.
                candidate = base_dir / candidate
            resolved = _safe_resolve(candidate)
            if resolved in seen_paths:
                # The absolute and relative keys may name the same file.
                continue
            seen_paths.add(resolved)
            sig = _safe_stat_signature(candidate)
            sig["kind"] = kind
            source_signatures.append(sig)

    signature["sources"] = source_signatures
    return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager: class EmbeddingServerManager:
""" """
A simplified manager for embedding server processes that avoids complex update mechanisms. A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -165,14 +85,13 @@ class EmbeddingServerManager:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here # passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None) provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = self._build_config_signature( config_signature = {
model_name=model_name, "model_name": model_name,
embedding_mode=embedding_mode, "passages_file": kwargs.get("passages_file", ""),
provider_options=provider_options, "embedding_mode": embedding_mode,
passages_file=passages_file, "provider_options": provider_options or {},
) }
# If this manager already has a live server, just reuse it # If this manager already has a live server, just reuse it
if ( if (
@@ -196,7 +115,6 @@ class EmbeddingServerManager:
port, port,
model_name, model_name,
embedding_mode, embedding_mode,
config_signature=config_signature,
provider_options=provider_options, provider_options=provider_options,
**kwargs, **kwargs,
) )
@@ -218,30 +136,11 @@ class EmbeddingServerManager:
**kwargs, **kwargs,
) )
def _build_config_signature(
    self,
    *,
    model_name: str,
    embedding_mode: str,
    provider_options: Optional[dict],
    passages_file: Optional[str],
) -> dict:
    """Assemble the dict that identifies a server configuration.

    ``passages_file`` and ``provider_options`` are normalised to ``""`` and
    ``{}`` when falsy, and ``"passages_signature"`` holds the result of
    ``_build_passages_signature`` for the given file.
    """
    normalized_passages = passages_file or ""
    normalized_options = provider_options or {}
    return dict(
        model_name=model_name,
        passages_file=normalized_passages,
        embedding_mode=embedding_mode,
        provider_options=normalized_options,
        passages_signature=_build_passages_signature(passages_file),
    )
def _start_server_colab( def _start_server_colab(
self, self,
port: int, port: int,
model_name: str, model_name: str,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
*,
config_signature: Optional[dict] = None,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
@@ -264,11 +163,10 @@ class EmbeddingServerManager:
command, command,
actual_port, actual_port,
provider_options=provider_options, provider_options=provider_options,
config_signature=config_signature,
) )
started, ready_port = self._wait_for_server_ready_colab(actual_port) started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started: if started:
self._server_config = config_signature or { self._server_config = {
"model_name": model_name, "model_name": model_name,
"passages_file": kwargs.get("passages_file", ""), "passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode, "embedding_mode": embedding_mode,
@@ -300,7 +198,6 @@ class EmbeddingServerManager:
command, command,
port, port,
provider_options=provider_options, provider_options=provider_options,
config_signature=config_signature,
) )
started, ready_port = self._wait_for_server_ready(port) started, ready_port = self._wait_for_server_ready(port)
if started: if started:
@@ -344,9 +241,7 @@ class EmbeddingServerManager:
self, self,
command: list, command: list,
port: int, port: int,
*,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None: ) -> None:
"""Launch the server process.""" """Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -381,29 +276,26 @@ class EmbeddingServerManager:
) )
self.server_port = port self.server_port = port
# Record config for in-process reuse (best effort; refined later when ready) # Record config for in-process reuse (best effort; refined later when ready)
if config_signature is not None: try:
self._server_config = config_signature self._server_config = {
else: # Fallback for unexpected code paths "model_name": command[command.index("--model-name") + 1]
try: if "--model-name" in command
self._server_config = { else "",
"model_name": command[command.index("--model-name") + 1] "passages_file": command[command.index("--passages-file") + 1]
if "--model-name" in command if "--passages-file" in command
else "", else "",
"passages_file": command[command.index("--passages-file") + 1] "embedding_mode": command[command.index("--embedding-mode") + 1]
if "--passages-file" in command if "--embedding-mode" in command
else "", else "sentence-transformers",
"embedding_mode": command[command.index("--embedding-mode") + 1] "provider_options": provider_options or {},
if "--embedding-mode" in command }
else "sentence-transformers", except Exception:
"provider_options": provider_options or {}, self._server_config = {
} "model_name": "",
except Exception: "passages_file": "",
self._server_config = { "embedding_mode": "sentence-transformers",
"model_name": "", "provider_options": provider_options or {},
"passages_file": "", }
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
# Register atexit callback only when we actually start a process # Register atexit callback only when we actually start a process
@@ -511,9 +403,7 @@ class EmbeddingServerManager:
self, self,
command: list, command: list,
port: int, port: int,
*,
provider_options: Optional[dict] = None, provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
) -> None: ) -> None:
"""Launch the server process with Colab-specific settings.""" """Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}") logger.info(f"Colab Command: {' '.join(command)}")
@@ -539,15 +429,12 @@ class EmbeddingServerManager:
atexit.register(self._finalize_process) atexit.register(self._finalize_process)
self._atexit_registered = True self._atexit_registered = True
# Record config for in-process reuse is best-effort in Colab mode # Record config for in-process reuse is best-effort in Colab mode
if config_signature is not None: self._server_config = {
self._server_config = config_signature "model_name": "",
else: "passages_file": "",
self._server_config = { "embedding_mode": "sentence-transformers",
"model_name": "", "provider_options": provider_options or {},
"passages_file": "", }
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
"""Wait for the server to be ready with Colab-specific timeout.""" """Wait for the server to be ready with Colab-specific timeout."""

14
tests/test_cli_ask.py Normal file
View File

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
    """The ``ask`` subcommand parses a question given as a positional argument."""
    monkeypatch.chdir(tmp_path)

    question = "Where are prompts configured?"
    argv = ["ask", "my-docs", question]
    parsed = LeannCLI().create_parser().parse_args(argv)

    assert parsed.command == "ask"
    assert parsed.index_name == "my-docs"
    assert parsed.query == question

View File

@@ -1,137 +0,0 @@
import json
import time
import pytest
from leann.embedding_server_manager import EmbeddingServerManager
class DummyProcess:
    """Minimal in-memory stand-in for a subprocess handle.

    It starts "running"; ``terminate``, ``kill`` and ``wait`` all flip it to
    the stopped state, after which ``poll`` reports exit code ``0``.
    """

    def __init__(self):
        self._terminated = False
        self.pid = 12345

    def _mark_terminated(self):
        # Single place that records the process as stopped.
        self._terminated = True

    def poll(self):
        """Return ``None`` while running and ``0`` once stopped."""
        if self._terminated:
            return 0
        return None

    def terminate(self):
        self._mark_terminated()

    def kill(self):
        self._mark_terminated()

    def wait(self, timeout=None):
        """Stop immediately (``timeout`` is ignored) and return exit code ``0``."""
        self._mark_terminated()
        return 0
@pytest.fixture
def embedding_manager(monkeypatch):
    """Provide a manager whose process handling is replaced by in-memory fakes.

    Returns a ``(manager, start_calls)`` pair; ``start_calls`` receives the
    ``config_signature`` passed to each simulated server start.
    """
    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")
    start_calls = []

    def _same_port(start_port):
        return start_port

    def _record_start(self, port, model_name, embedding_mode, **kwargs):
        signature = kwargs.get("config_signature")
        start_calls.append(signature)
        self.server_process = DummyProcess()
        self.server_port = port
        self._server_config = signature
        return True, port

    def _reset_state(self):
        # Ensure stop_server doesn't try to operate on real subprocesses
        self.server_process = None
        self.server_port = None
        self._server_config = None

    monkeypatch.setattr(
        "leann.embedding_server_manager._get_available_port",
        _same_port,
    )
    monkeypatch.setattr(EmbeddingServerManager, "_start_new_server", _record_start)
    monkeypatch.setattr(EmbeddingServerManager, "stop_server", _reset_state)

    return manager, start_calls
def _write_meta(meta_path, passages_name, index_name, total):
meta_path.write_text(
json.dumps(
{
"backend_name": "hnsw",
"embedding_model": "test-model",
"embedding_mode": "sentence-transformers",
"dimensions": 3,
"backend_kwargs": {},
"passage_sources": [
{
"type": "jsonl",
"path": passages_name,
"index_path": index_name,
}
],
"total_passages": total,
}
),
encoding="utf-8",
)
def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
    """Unchanged metadata reuses the server; changed metadata starts a new one."""
    manager, start_calls = embedding_manager

    meta_path = tmp_path / "example.meta.json"
    passages_path = tmp_path / "example.passages.jsonl"
    index_path = tmp_path / "example.passages.idx"

    def launch():
        # Every call in this test uses the same arguments on purpose; only the
        # files on disk change between calls.
        return manager.start_server(
            port=6000,
            model_name="test-model",
            passages_file=str(meta_path),
        )

    passages_path.write_text("first\n", encoding="utf-8")
    index_path.write_bytes(b"index")
    _write_meta(meta_path, passages_path.name, index_path.name, total=1)

    # Initial start populates signature
    started, first_port = launch()
    assert started
    assert first_port == 6000
    assert len(start_calls) == 1
    signature_before = start_calls[0]["passages_signature"]

    # No metadata change => reuse existing server
    started, second_port = launch()
    assert started
    assert second_port == 6000
    assert len(start_calls) == 1

    # Modify passage data and metadata to force signature change
    time.sleep(0.01)  # Ensure filesystem timestamps move forward
    passages_path.write_text("second\n", encoding="utf-8")
    _write_meta(meta_path, passages_path.name, index_path.name, total=2)

    started, third_port = launch()
    assert started
    assert third_port == 6000
    assert len(start_calls) == 2

    signature_after = start_calls[1]["passages_signature"]
    assert signature_after != signature_before