merge main

yichuan520030910320
2025-09-23 23:21:53 -07:00
5 changed files with 399 additions and 67 deletions


@@ -182,7 +182,10 @@ LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`,
### Generation Model Setup
#### LLM Backend
LEANN supports many LLM providers for text generation (HuggingFace, Ollama, and any OpenAI-compatible API).
<details>
<summary><strong>🔑 OpenAI API Setup (Default)</strong></summary>
@@ -193,6 +196,68 @@ Set your OpenAI API key as an environment variable:
export OPENAI_API_KEY="your-api-key-here"
```
Make sure to pass the `--llm openai` flag when using the CLI.
You can also specify the model name with the `--llm-model <model-name>` flag.
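For example, a minimal invocation might look like the following (the model name below is only a placeholder; use any chat model available to your account):

```sh
# "my-docs" is an existing index; "gpt-4o-mini" is a placeholder model name
leann ask my-docs "What does this project do?" --llm openai --llm-model gpt-4o-mini
```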
</details>
<details>
<summary><strong>🛠️ Supported LLM & Embedding Providers (via OpenAI Compatibility)</strong></summary>
Thanks to the widespread adoption of the OpenAI API format, LEANN works out of the box with a wide range of LLM and embedding providers. Simply set the `OPENAI_BASE_URL` and `OPENAI_API_KEY` environment variables to connect to your preferred service.
```sh
export OPENAI_API_KEY="xxx"
export OPENAI_BASE_URL="http://localhost:1234/v1" # base URL of the provider
```
To use an OpenAI-compatible endpoint with the CLI:

- For text generation, pass the `--llm openai` flag and specify the model name with `--llm-model <model-name>`.
- For embeddings, pass the `--embedding-mode openai` flag and specify the model name with `--embedding-model <model-name>`.
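As a concrete sketch, the snippet below points the CLI at a local OpenAI-compatible server (the Ollama endpoint and model name are assumptions; substitute your own provider and models):

```sh
export OPENAI_API_KEY="local-key"                   # many local engines accept any non-empty key
export OPENAI_BASE_URL="http://localhost:11434/v1"  # e.g. Ollama's OpenAI-compatible endpoint

# Generation through the OpenAI-compatible backend ("llama3.1" is a placeholder model name)
leann ask my-docs "Where are prompts configured?" --llm openai --llm-model llama3.1
```

The same two environment variables are honored when you switch embeddings to `--embedding-mode openai` with `--embedding-model <model-name>`.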
-----
Below is a list of base URLs for common providers to get you started.
### 🖥️ Local Inference Engines (Recommended for full privacy)
| Provider | Sample Base URL |
| ---------------- | --------------------------- |
| **Ollama** | `http://localhost:11434/v1` |
| **LM Studio** | `http://localhost:1234/v1` |
| **vLLM** | `http://localhost:8000/v1` |
| **llama.cpp** | `http://localhost:8080/v1` |
| **SGLang** | `http://localhost:30000/v1` |
| **LiteLLM** | `http://localhost:4000` |
-----
### ☁️ Cloud Providers
> **🚨 A Note on Privacy:** Before choosing a cloud provider, review its privacy and data-retention policies carefully. Depending on the terms, your data may be used for the provider's own purposes, including human review and model training, which may be unacceptable for sensitive data.
| Provider | Base URL |
| ---------------- | ---------------------------------------------------------- |
| **OpenAI** | `https://api.openai.com/v1` |
| **OpenRouter** | `https://openrouter.ai/api/v1` |
| **Gemini** | `https://generativelanguage.googleapis.com/v1beta/openai/` |
| **x.AI (Grok)** | `https://api.x.ai/v1` |
| **Groq AI** | `https://api.groq.com/openai/v1` |
| **DeepSeek** | `https://api.deepseek.com/v1` |
| **SiliconFlow** | `https://api.siliconflow.cn/v1` |
| **Zhipu (BigModel)** | `https://open.bigmodel.cn/api/paas/v4/` |
| **Mistral AI** | `https://api.mistral.ai/v1` |
If your provider isn't on this list, don't worry! Check its documentation for an OpenAI-compatible endpoint; most providers offer one.
</details>

<details>
@@ -546,6 +611,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes
leann list


@@ -257,6 +257,11 @@ Examples:
        # Ask command
        ask_parser = subparsers.add_parser("ask", help="Ask questions")
        ask_parser.add_argument("index_name", help="Index name")
        ask_parser.add_argument(
            "query",
            nargs="?",
            help="Question to ask (omit for prompt or when using --interactive)",
        )
        ask_parser.add_argument(
            "--llm",
            type=str,
@@ -1531,7 +1536,29 @@ Examples:
        chat = LeannChat(index_path=index_path, llm_config=llm_config)
        llm_kwargs: dict[str, Any] = {}
        if args.thinking_budget:
            llm_kwargs["thinking_budget"] = args.thinking_budget

        def _ask_once(prompt: str) -> None:
            response = chat.ask(
                prompt,
                top_k=args.top_k,
                complexity=args.complexity,
                beam_width=args.beam_width,
                prune_ratio=args.prune_ratio,
                recompute_embeddings=args.recompute_embeddings,
                pruning_strategy=args.pruning_strategy,
                llm_kwargs=llm_kwargs,
            )
            print(f"LEANN: {response}")

        initial_query = (args.query or "").strip()

        if args.interactive:
            if initial_query:
                _ask_once(initial_query)
            print("LEANN Assistant ready! Type 'quit' to exit")
            print("=" * 40)
@@ -1544,41 +1571,14 @@ Examples:
                if not user_input:
                    continue

                _ask_once(user_input)
        else:
            query = initial_query or input("Enter your question: ").strip()
            if not query:
                print("No question provided. Exiting.")
                return
            _ask_once(query)

    async def run(self, args=None):
        parser = self.create_parser()


@@ -1,4 +1,5 @@
import atexit
import json
import logging
import os
import socket
@@ -48,6 +49,85 @@ def _check_port(port: int) -> bool:
# Note: All cross-process scanning helpers removed for simplicity
def _safe_resolve(path: Path) -> str:
    """Resolve paths safely even if the target does not yet exist."""
    try:
        return str(path.resolve(strict=False))
    except Exception:
        return str(path)


def _safe_stat_signature(path: Path) -> dict:
    """Return a lightweight signature describing the current state of a path."""
    signature: dict[str, object] = {"path": _safe_resolve(path)}
    try:
        stat = path.stat()
    except FileNotFoundError:
        signature["missing"] = True
    except Exception as exc:  # pragma: no cover - unexpected filesystem errors
        signature["error"] = str(exc)
    else:
        signature["mtime_ns"] = stat.st_mtime_ns
        signature["size"] = stat.st_size
    return signature


def _build_passages_signature(passages_file: Optional[str]) -> Optional[dict]:
    """Collect modification signatures for metadata and referenced passage files."""
    if not passages_file:
        return None

    meta_path = Path(passages_file)
    signature: dict[str, object] = {"meta": _safe_stat_signature(meta_path)}

    try:
        with meta_path.open(encoding="utf-8") as fh:
            meta = json.load(fh)
    except FileNotFoundError:
        signature["meta_missing"] = True
        signature["sources"] = []
        return signature
    except json.JSONDecodeError as exc:
        signature["meta_error"] = f"json_error:{exc}"
        signature["sources"] = []
        return signature
    except Exception as exc:  # pragma: no cover - unexpected errors
        signature["meta_error"] = str(exc)
        signature["sources"] = []
        return signature

    base_dir = meta_path.parent
    seen_paths: set[str] = set()
    source_signatures: list[dict[str, object]] = []
    for source in meta.get("passage_sources", []):
        for key, kind in (
            ("path", "passages"),
            ("path_relative", "passages"),
            ("index_path", "index"),
            ("index_path_relative", "index"),
        ):
            raw_path = source.get(key)
            if not raw_path:
                continue
            candidate = Path(raw_path)
            if not candidate.is_absolute():
                candidate = base_dir / candidate
            resolved = _safe_resolve(candidate)
            if resolved in seen_paths:
                continue
            seen_paths.add(resolved)
            sig = _safe_stat_signature(candidate)
            sig["kind"] = kind
            source_signatures.append(sig)

    signature["sources"] = source_signatures
    return signature
# Note: All cross-process scanning helpers removed for simplicity
class EmbeddingServerManager:
    """
    A simplified manager for embedding server processes that avoids complex update mechanisms.
@@ -85,13 +165,14 @@ class EmbeddingServerManager:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here # passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None) provider_options = kwargs.pop("provider_options", None)
passages_file = kwargs.get("passages_file", "")
config_signature = { config_signature = self._build_config_signature(
"model_name": model_name, model_name=model_name,
"passages_file": kwargs.get("passages_file", ""), embedding_mode=embedding_mode,
"embedding_mode": embedding_mode, provider_options=provider_options,
"provider_options": provider_options or {}, passages_file=passages_file,
} )
# If this manager already has a live server, just reuse it # If this manager already has a live server, just reuse it
if ( if (
@@ -115,6 +196,7 @@ class EmbeddingServerManager:
            port,
            model_name,
            embedding_mode,
            config_signature=config_signature,
            provider_options=provider_options,
            **kwargs,
        )
@@ -136,11 +218,30 @@ class EmbeddingServerManager:
            **kwargs,
        )

    def _build_config_signature(
        self,
        *,
        model_name: str,
        embedding_mode: str,
        provider_options: Optional[dict],
        passages_file: Optional[str],
    ) -> dict:
        """Create a signature describing the current server configuration."""
        return {
            "model_name": model_name,
            "passages_file": passages_file or "",
            "embedding_mode": embedding_mode,
            "provider_options": provider_options or {},
            "passages_signature": _build_passages_signature(passages_file),
        }

    def _start_server_colab(
        self,
        port: int,
        model_name: str,
        embedding_mode: str = "sentence-transformers",
        *,
        config_signature: Optional[dict] = None,
        provider_options: Optional[dict] = None,
        **kwargs,
    ) -> tuple[bool, int]:
@@ -163,10 +264,11 @@ class EmbeddingServerManager:
            command,
            actual_port,
            provider_options=provider_options,
            config_signature=config_signature,
        )
        started, ready_port = self._wait_for_server_ready_colab(actual_port)
        if started:
            self._server_config = config_signature or {
                "model_name": model_name,
                "passages_file": kwargs.get("passages_file", ""),
                "embedding_mode": embedding_mode,
@@ -198,6 +300,7 @@ class EmbeddingServerManager:
            command,
            port,
            provider_options=provider_options,
            config_signature=config_signature,
        )
        started, ready_port = self._wait_for_server_ready(port)
        if started:
@@ -241,7 +344,9 @@ class EmbeddingServerManager:
        self,
        command: list,
        port: int,
        *,
        provider_options: Optional[dict] = None,
        config_signature: Optional[dict] = None,
    ) -> None:
        """Launch the server process."""
        project_root = Path(__file__).parent.parent.parent.parent.parent
@@ -276,26 +381,29 @@ class EmbeddingServerManager:
        )
        self.server_port = port

        # Record config for in-process reuse (best effort; refined later when ready)
        if config_signature is not None:
            self._server_config = config_signature
        else:  # Fallback for unexpected code paths
            try:
                self._server_config = {
                    "model_name": command[command.index("--model-name") + 1]
                    if "--model-name" in command
                    else "",
                    "passages_file": command[command.index("--passages-file") + 1]
                    if "--passages-file" in command
                    else "",
                    "embedding_mode": command[command.index("--embedding-mode") + 1]
                    if "--embedding-mode" in command
                    else "sentence-transformers",
                    "provider_options": provider_options or {},
                }
            except Exception:
                self._server_config = {
                    "model_name": "",
                    "passages_file": "",
                    "embedding_mode": "sentence-transformers",
                    "provider_options": provider_options or {},
                }

        logger.info(f"Server process started with PID: {self.server_process.pid}")

        # Register atexit callback only when we actually start a process
@@ -403,7 +511,9 @@ class EmbeddingServerManager:
        self,
        command: list,
        port: int,
        *,
        provider_options: Optional[dict] = None,
        config_signature: Optional[dict] = None,
    ) -> None:
        """Launch the server process with Colab-specific settings."""
        logger.info(f"Colab Command: {' '.join(command)}")
@@ -429,12 +539,15 @@ class EmbeddingServerManager:
        atexit.register(self._finalize_process)
        self._atexit_registered = True

        # Record config for in-process reuse is best-effort in Colab mode
        if config_signature is not None:
            self._server_config = config_signature
        else:
            self._server_config = {
                "model_name": "",
                "passages_file": "",
                "embedding_mode": "sentence-transformers",
                "provider_options": provider_options or {},
            }

    def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
        """Wait for the server to be ready with Colab-specific timeout."""

tests/test_cli_ask.py Normal file

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI


def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
    monkeypatch.chdir(tmp_path)
    cli = LeannCLI()
    parser = cli.create_parser()

    args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])

    assert args.command == "ask"
    assert args.index_name == "my-docs"
    assert args.query == "Where are prompts configured?"


@@ -0,0 +1,137 @@
import json
import time

import pytest

from leann.embedding_server_manager import EmbeddingServerManager


class DummyProcess:
    def __init__(self):
        self.pid = 12345
        self._terminated = False

    def poll(self):
        return 0 if self._terminated else None

    def terminate(self):
        self._terminated = True

    def kill(self):
        self._terminated = True

    def wait(self, timeout=None):
        self._terminated = True
        return 0


@pytest.fixture
def embedding_manager(monkeypatch):
    manager = EmbeddingServerManager("leann_backend_hnsw.hnsw_embedding_server")

    def fake_get_available_port(start_port):
        return start_port

    monkeypatch.setattr(
        "leann.embedding_server_manager._get_available_port",
        fake_get_available_port,
    )

    start_calls = []

    def fake_start_new_server(self, port, model_name, embedding_mode, **kwargs):
        config_signature = kwargs.get("config_signature")
        start_calls.append(config_signature)
        self.server_process = DummyProcess()
        self.server_port = port
        self._server_config = config_signature
        return True, port

    monkeypatch.setattr(
        EmbeddingServerManager,
        "_start_new_server",
        fake_start_new_server,
    )

    # Ensure stop_server doesn't try to operate on real subprocesses
    def fake_stop_server(self):
        self.server_process = None
        self.server_port = None
        self._server_config = None

    monkeypatch.setattr(EmbeddingServerManager, "stop_server", fake_stop_server)

    return manager, start_calls


def _write_meta(meta_path, passages_name, index_name, total):
    meta_path.write_text(
        json.dumps(
            {
                "backend_name": "hnsw",
                "embedding_model": "test-model",
                "embedding_mode": "sentence-transformers",
                "dimensions": 3,
                "backend_kwargs": {},
                "passage_sources": [
                    {
                        "type": "jsonl",
                        "path": passages_name,
                        "index_path": index_name,
                    }
                ],
                "total_passages": total,
            }
        ),
        encoding="utf-8",
    )


def test_server_restarts_when_metadata_changes(tmp_path, embedding_manager):
    manager, start_calls = embedding_manager
    meta_path = tmp_path / "example.meta.json"
    passages_path = tmp_path / "example.passages.jsonl"
    index_path = tmp_path / "example.passages.idx"

    passages_path.write_text("first\n", encoding="utf-8")
    index_path.write_bytes(b"index")
    _write_meta(meta_path, passages_path.name, index_path.name, total=1)

    # Initial start populates signature
    ok, port = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port == 6000
    assert len(start_calls) == 1
    initial_signature = start_calls[0]["passages_signature"]

    # No metadata change => reuse existing server
    ok, port_again = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_again == 6000
    assert len(start_calls) == 1

    # Modify passage data and metadata to force signature change
    time.sleep(0.01)  # Ensure filesystem timestamps move forward
    passages_path.write_text("second\n", encoding="utf-8")
    _write_meta(meta_path, passages_path.name, index_path.name, total=2)

    ok, port_third = manager.start_server(
        port=6000,
        model_name="test-model",
        passages_file=str(meta_path),
    )
    assert ok
    assert port_third == 6000
    assert len(start_calls) == 2
    updated_signature = start_calls[1]["passages_signature"]
    assert updated_signature != initial_signature