Compare commits

...

12 Commits

Author SHA1 Message Date
Andy Lee
47aeb85f82 Allow 'leann ask' to accept a positional question 2025-09-23 15:18:51 -07:00
Andy Lee
db7ba27ff6 feat: Add support for configurable local LLM endpoints (#115)
* feat: support configurable local llm endpoints

* docs
2025-09-23 15:12:13 -07:00
Andy Lee
5f7806e16f Introducing dynamic index update (#108)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly

* fix: enable is_compact=False, is_recompute=True

* feat: update when recompute

* test

* fix: real recompute

* refactor

* fix: compare with no-recompute

* fix: test
2025-09-21 22:56:27 -07:00
yichuan-w
d034e2195b fix build from source in diskann 2025-09-20 19:52:29 +00:00
yichuan520030910320
43894ff605 update submodule 2025-09-19 17:03:55 -07:00
yichuan520030910320
10311cc611 change the submodule for easy pull 2025-09-19 17:02:09 -07:00
Andy Lee
ad0d2faabc feat: Add GitHub PR and issue templates (#105)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly
2025-09-19 13:51:36 -07:00
Andy Lee
e93c0dec6f [Fix] Enable AST chunking when installed (package chunking utils) (#101)
* fix(core): package chunking utils for AST chunking; re-export in apps; CLI imports packaged utils

* style

* chore: fix ruff warnings (RUF059, F401)

* style
2025-09-17 18:44:00 -07:00
GitHub Actions
c5a29f849a chore: release v0.3.4 2025-09-16 20:45:22 +00:00
Yichuan Wang
3b8dc6368e Ast fork (#92) 2025-09-08 18:43:31 -07:00
Aiden Huang
e309f292de docs(mcp): add root llms.txt for MCP discovery; update MCP README to reference it; refs #76 (#91) 2025-09-07 14:39:58 -07:00
AWS Mcleod
0d9f92ea0f Add grep search functionality - Issue #86 (#87)
* Add grep search functionality to LeannSearcher

- Add use_grep parameter to search method
- Implement grep-based search on .jsonl files
- Add fallback Python regex search
- Support same SearchResult format as semantic search

Addresses issue #86

* fix: resolve linting errors

* docs: add grep search example

* docs: add grep search to README examples

* refactor: remove regex fallback, move grep example to features section

* docs: add grep search to Advanced Features with comprehensive guide
2025-09-05 13:48:07 -07:00
39 changed files with 5656 additions and 3945 deletions

.github/ISSUE_TEMPLATE/bug_report.yml (new file, 50 lines)

@@ -0,0 +1,50 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: What happened?
description: A clear description of the bug
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: How to reproduce
placeholder: |
1. Install with...
2. Run command...
3. See error
validations:
required: true
- type: textarea
id: error
attributes:
label: Error message
description: Paste any error messages
render: shell
- type: input
id: version
attributes:
label: LEANN Version
placeholder: "0.1.0"
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating System
options:
- macOS
- Linux
- Windows
- Docker
validations:
required: true

.github/ISSUE_TEMPLATE/config.yml (new file, 8 lines)

@@ -0,0 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: Documentation
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
about: Read the docs first
- name: Discussions
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
about: Ask questions and share ideas

View File

@@ -0,0 +1,27 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]
body:
- type: textarea
id: problem
attributes:
label: What problem does this solve?
description: Describe the problem or need
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: How would you like this to work?
validations:
required: true
- type: textarea
id: example
attributes:
label: Example usage
description: Show how the API might look
render: python

.github/pull_request_template.md (new file, 13 lines)

@@ -0,0 +1,13 @@
## What does this PR do?
<!-- Brief description of your changes -->
## Related Issues
Fixes #
## Checklist
- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)

.gitignore (1 line changed)

@@ -22,6 +22,7 @@ demo/experiment_results/**/*.json
*.sh
*.txt
!CMakeLists.txt
+!llms.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/

.gitmodules (3 lines changed)

@@ -14,3 +14,6 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"] [submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = https://github.com/yichuan-w/astchunk-leann.git

View File

@@ -546,6 +546,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
+# Ask a single question (non-interactive)
+leann ask my-docs "Where are prompts configured?"
# List all your indexes
leann list
@@ -656,6 +659,19 @@ results = searcher.search(
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
+### 🔍 Grep Search
+For exact text matching instead of semantic search, use the `use_grep` parameter:
+```python
+# Exact text search
+results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
+```
+**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
+📖 **[Complete grep search guide →](docs/grep_search.md)**
## 🏗️ Architecture & How It Works
<p align="center">

View File

@@ -11,6 +11,7 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory
+from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv()
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
+embedding_group.add_argument(
+"--embedding-host",
+type=str,
+default=None,
+help="Override Ollama-compatible embedding host",
+)
+embedding_group.add_argument(
+"--embedding-api-base",
+type=str,
+default=None,
+help="Base URL for OpenAI-compatible embedding services",
+)
+embedding_group.add_argument(
+"--embedding-api-key",
+type=str,
+default=None,
+help="API key for embedding service (defaults to OPENAI_API_KEY)",
+)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
llm_group.add_argument(
"--llm-host",
type=str,
-default="http://localhost:11434",
+default=None,
-help="Host for Ollama API (default: http://localhost:11434)",
+help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
llm_group.add_argument(
"--thinking-budget",
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
+llm_group.add_argument(
+"--llm-api-base",
+type=str,
+default=None,
+help="Base URL for OpenAI-compatible APIs",
+)
+llm_group.add_argument(
+"--llm-api-key",
+type=str,
+default=None,
+help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
+)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
+config["base_url"] = resolve_openai_base_url(args.llm_api_base)
+resolved_key = resolve_openai_api_key(args.llm_api_key)
+if resolved_key:
+config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
-config["host"] = args.llm_host
+config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
+embedding_options: dict[str, Any] = {}
+if args.embedding_mode == "ollama":
+embedding_options["host"] = resolve_ollama_host(args.embedding_host)
+elif args.embedding_mode == "openai":
+embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
+resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
+if resolved_embedding_key:
+embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
+embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,

View File

@@ -1,16 +1,38 @@
""" """Unified chunking utilities facade.
Chunking utilities for LEANN RAG applications.
Provides AST-aware and traditional text chunking functionality. This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
""" """
from .utils import ( import sys
from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS, CODE_EXTENSIONS,
create_ast_chunks, create_ast_chunks,
create_text_chunks, create_text_chunks,
create_traditional_chunks, create_traditional_chunks,
detect_code_files, detect_code_files,
get_language_from_extension, get_language_from_extension,
) )
except Exception: # pragma: no cover - best-effort fallback for dev environment
repo_root = Path(__file__).resolve().parents[2]
leann_src = repo_root / "packages" / "leann-core" / "src"
if leann_src.exists():
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
else:
raise
__all__ = [ __all__ = [
"CODE_EXTENSIONS", "CODE_EXTENSIONS",

View File

@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
if count >= max_count and max_count > 0:
break
-last_visit, url, title, visit_count, typed_count, hidden = row
+last_visit, url, title, visit_count, typed_count, _hidden = row
# Create document content with metadata embedded in text
doc_content = f"""

View File

@@ -26,6 +26,21 @@ leann build my-code-index --docs ./src --use-ast-chunking
uv pip install -e "."
```
+#### For normal users (PyPI install)
+- Use `pip install leann` or `uv pip install leann`.
+- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
+#### For developers (from source, editable)
+```bash
+git clone https://github.com/yichuan-w/LEANN.git leann
+cd leann
+git submodule update --init --recursive
+uv sync
+```
+- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
+- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
+- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed); a quick way to verify this is sketched below.
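To confirm the editable mapping is actually in effect, you can check where Python resolves the package from. This is a minimal illustrative sketch, assuming the import name matches the PyPI package name `astchunk`:

```python
# Minimal check (illustrative): where does Python import astchunk from?
# Assumption: the import name matches the package name "astchunk".
import pathlib

import astchunk

resolved = pathlib.Path(astchunk.__file__).resolve()
print(resolved)
# In a source checkout set up with `uv sync`, this should point inside
# packages/astchunk-leann rather than your environment's site-packages.
```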
## Best Practices
### When to Use AST Chunking

View File

@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
</details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home machine or workstation and the index/search job runs in the cloud, the cloud job must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
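To see exactly what was persisted, you can open the index's metadata file directly. This is a minimal sketch for illustration; it assumes the metadata sits next to the index as `<index-name>.meta.json`, so adjust the path to your setup:

```python
# Minimal sketch: inspect the provider settings stored at build time.
# Assumption: the metadata file lives next to the index as <index-name>.meta.json.
import json
from pathlib import Path

meta_path = Path("./indexes/my-notes.meta.json")  # adjust to your index location
meta = json.loads(meta_path.read_text())
print(meta.get("embedding_options"))  # e.g. {"base_url": "...", "api_key": "..."}
```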
**Tip:** If your runtime does not require an API key (many local stacks don't), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
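Because the options are stored with the index, a later search or chat session does not need to repeat them; a minimal sketch (the query string is just an example):

```python
# Minimal sketch: a later session reuses the provider settings stored in meta.json,
# so no embedding_options need to be passed again.
from leann.api import LeannSearcher

searcher = LeannSearcher("./indexes/my-notes")
results = searcher.search("local endpoint configuration", top_k=3)  # example query
for result in results:
    print(f"{result.score:.3f} {result.text[:80]}")
```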
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)

docs/grep_search.md (new file, 149 lines)

@@ -0,0 +1,149 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
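To make the ranking step concrete, here is a minimal sketch of the scoring-and-sorting logic described above, assuming the matching JSONL lines have already been parsed into dicts with a `text` field (the field name here is only for illustration):

```python
# Minimal sketch of frequency-based ranking over parsed grep hits.
# Assumption: each hit is a dict with a "text" field taken from the .jsonl passages.
def rank_grep_hits(hits: list[dict], query: str, top_k: int = 5) -> list[dict]:
    scored = []
    for hit in hits:
        score = hit["text"].lower().count(query.lower())  # term frequency
        if score > 0:
            scored.append((score, hit))
    scored.sort(key=lambda pair: pair[0], reverse=True)  # highest score first
    return [hit for _, hit in scored[:top_k]]


hits = [{"text": "def train_model(data): ..."}, {"text": "model = load_model()"}]
print(rank_grep_hits(hits, "def train_model", top_k=3))
```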
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

examples/__init__.py (new empty file)

View File

@@ -0,0 +1,404 @@
"""Dynamic HNSW update demo without compact storage.
This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
uv run -m examples.dynamic_update_no_recompute \
--index-path .leann/examples/leann-demo.leann
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.
To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""
import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory
from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path]) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
return [c for c in chunks if isinstance(c, str) and c.strip()]
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
searcher = LeannSearcher(str(index_path))
try:
return searcher.search(
query=query,
top_k=top_k,
recompute_embeddings=recompute_embeddings,
batch_size=16,
)
finally:
searcher.cleanup()
def print_results(title: str, results: Iterable) -> None:
print(f"\n=== {title} ===")
res_list = list(results)
print(f"results count: {len(res_list)}")
print("passages:")
if not res_list:
print(" (no passages returned)")
for res in res_list:
snippet = res.text.replace("\n", " ")[:120]
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def update_index(
index_path: Path,
start_id: int,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
updater = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for offset, passage in enumerate(paragraphs, start=start_id):
updater.add_text(passage, metadata={"id": str(offset)})
updater.update_index(str(index_path))
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
"""Remove leftover index artifacts for a clean rebuild."""
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def index_file_size(index_path: Path) -> int:
"""Return the size of the primary .index file for the given index path."""
index_file = index_path.parent / f"{index_path.stem}.index"
return index_file.stat().st_size if index_file.exists() else 0
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
return None
try:
return json.loads(meta_path.read_text())
except json.JSONDecodeError:
return None
def run_workflow(
*,
label: str,
index_path: Path,
initial_paragraphs: list[str],
update_paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
query: str,
top_k: int,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
ensure_index_dir(index_path)
cleanup_index_files(index_path)
print(f"{prefix}Building initial index...")
build_initial_index(
index_path,
initial_paragraphs,
model_name,
embedding_mode,
is_recompute=is_recompute,
)
initial_size = index_file_size(index_path)
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
print(f"\n{prefix}Updating index with additional passages...")
update_index(
index_path,
start_id=len(initial_paragraphs),
paragraphs=update_paragraphs,
model_name=model_name,
embedding_mode=embedding_mode,
is_recompute=is_recompute,
)
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results,
"after_results": after_results,
"metadata": load_metadata_snapshot(index_path),
}
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--initial-files",
type=Path,
nargs="+",
default=DEFAULT_INITIAL_FILES,
help="Initial document files (PDF/TXT) used to build the base index",
)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/examples/leann-demo.leann"),
help="Destination index path (default: .leann/examples/leann-demo.leann)",
)
parser.add_argument(
"--initial-count",
type=int,
default=8,
help="Number of chunks to use from the initial documents (default: 8)",
)
parser.add_argument(
"--update-files",
type=Path,
nargs="*",
default=DEFAULT_UPDATE_FILES,
help="Additional documents to add during update (PDF/TXT)",
)
parser.add_argument(
"--update-count",
type=int,
default=4,
help="Number of chunks to append from update documents (default: 4)",
)
parser.add_argument(
"--update-text",
type=str,
default=(
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
"on demand to keep disk usage minimal."
),
help="Fallback text to append if --update-files is omitted",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of results to show for each search (default: 4)",
)
parser.add_argument(
"--query",
type=str,
default=DEFAULT_QUERY,
help="Query to run before/after the update",
)
parser.add_argument(
"--embedding-model",
type=str,
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--compare-no-recompute",
dest="compare_no_recompute",
action="store_true",
help="Also run a baseline with is_recompute=False and report its index growth.",
)
parser.add_argument(
"--skip-compare-no-recompute",
dest="compare_no_recompute",
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
ensure_index_dir(args.index_path)
register_project_directory(REPO_ROOT)
initial_chunks = load_chunks_from_files(list(args.initial_files))
if not initial_chunks:
raise ValueError("No text chunks extracted from the initial files.")
initial = initial_chunks[: args.initial_count]
if not initial:
raise ValueError("Initial chunk set is empty after applying --initial-count.")
if args.update_files:
update_chunks = load_chunks_from_files(list(args.update_files))
if not update_chunks:
raise ValueError("No text chunks extracted from the update files.")
to_add = update_chunks[: args.update_count]
else:
if not args.update_text:
raise ValueError("Provide --update-files or --update-text for the update step.")
to_add = [args.update_text]
if not to_add:
raise ValueError("Update chunk set is empty after applying --update-count.")
recompute_stats = run_workflow(
label="recompute",
index_path=args.index_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=True,
query=args.query,
top_k=args.top_k,
)
print_results("initial search", recompute_stats["before_results"])
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
)
if recompute_stats["metadata"]:
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if args.compare_no_recompute:
baseline_path = (
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
)
baseline_stats = run_workflow(
label="no-recompute",
index_path=baseline_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=False,
query=args.query,
top_k=args.top_k,
)
print(
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
f"{baseline_stats['delta']})"
)
after_texts = [res.text for res in recompute_stats["after_results"]]
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."
)
else:
print("[no-recompute] WARNING: search results differ from recompute baseline.")
if baseline_stats["metadata"]:
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[no-recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,35 @@
"""
Grep Search Example
Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""
from leann import LeannSearcher
# Load your index
searcher = LeannSearcher("my-documents.leann")
# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
print(f"Score: {result.score:.3f}")
print(f"Text: {result.text[:80]}...")
print()
# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:80]}...")
print()
# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
# Search for class definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")

llms.txt (new file, 28 lines)

@@ -0,0 +1,28 @@
# llms.txt — LEANN MCP and Agent Integration
product: LEANN
homepage: https://github.com/yichuan-w/LEANN
contact: https://github.com/yichuan-w/LEANN/issues
# Installation
install: uv tool install leann-core --with leann
# MCP Server Entry Point
mcp.server: leann_mcp
mcp.protocol_version: 2024-11-05
# Tools
mcp.tools: leann_list, leann_search
mcp.tool.leann_list.description: List available LEANN indexes
mcp.tool.leann_list.input: {}
mcp.tool.leann_search.description: Semantic search across a named LEANN index
mcp.tool.leann_search.input.index_name: string, required
mcp.tool.leann_search.input.query: string, required
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
# Notes
note: Build indexes with `leann build <name> --docs <files...>` before searching.
example.add: claude mcp add --scope user leann-server -- leann_mcp
example.verify: claude mcp list | cat

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
import numpy as np
import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
logger.propagate = False
+_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
+try:
+PROVIDER_OPTIONS: dict[str, Any] = (
+json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
+)
+except json.JSONDecodeError:
+logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
+PROVIDER_OPTIONS = {}
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}")  # Show first 5
# Process embeddings using unified computation
-embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+embeddings = compute_embeddings(
+texts,
+model_name,
+mode=embedding_mode,
+provider_options=PROVIDER_OPTIONS,
+)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
continue
# Process the request
-embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+embeddings = compute_embeddings(
+texts,
+model_name,
+mode=embedding_mode,
+provider_options=PROVIDER_OPTIONS,
+)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation

View File

@@ -1,11 +1,11 @@
[build-system]
-requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
+requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
-version = "0.3.3"
-dependencies = ["leann-core==0.3.3", "numpy", "protobuf>=3.19.0"]
+version = "0.3.4"
+dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -5,6 +5,8 @@ import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -14,7 +14,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
-from .convert_to_csr import convert_hnsw_graph_to_csr
+from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
@@ -92,6 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_compact:
self._convert_to_csr(index_file)
+elif self.is_recompute:
+prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -133,10 +135,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
-self.is_compact, self.is_pruned = (
-self.meta.get("is_compact", True),
-self.meta.get("is_pruned", True),
-)
+backend_meta_kwargs = self.meta.get("backend_kwargs", {})
+self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
+default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
+self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
-# Ensure we have a handler if none exists
+# Ensure we have handlers if none exist
if not logger.handlers:
-handler = logging.StreamHandler()
+stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-handler.setFormatter(formatter)
+stream_handler.setFormatter(formatter)
-logger.addHandler(handler)
+logger.addHandler(stream_handler)
-logger.propagate = False
+log_path = os.getenv("LEANN_HNSW_LOG_PATH")
+if log_path:
+try:
+file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
+file_formatter = logging.Formatter(
+"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
+)
+file_handler.setFormatter(file_formatter)
+logger.addHandler(file_handler)
+except Exception as exc:  # pragma: no cover - best effort logging
+logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
+logger.propagate = False
+_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
+try:
+PROVIDER_OPTIONS: dict[str, Any] = (
+json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
+)
+except json.JSONDecodeError:
+logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
+PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -138,7 +160,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
-embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
+embeddings = compute_embeddings(
+request,
+model_name,
+mode=embedding_mode,
+provider_options=PROVIDER_OPTIONS,
+)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -187,7 +214,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
-texts, model_name, mode=embedding_mode
+texts,
+model_name,
+mode=embedding_mode,
+provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -252,7 +282,12 @@ def create_hnsw_embedding_server(
if texts:
try:
-embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
+embeddings = compute_embeddings(
+texts,
+model_name,
+mode=embedding_mode,
+provider_options=PROVIDER_OPTIONS,
+)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
-version = "0.3.3"
+version = "0.3.4"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
-"leann-core==0.3.3",
+"leann-core==0.3.4",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
-version = "0.3.3"
+version = "0.3.4"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -6,6 +6,8 @@ with the correct, original embedding logic from the user's reference code.
import json import json
import logging import logging
import pickle import pickle
import re
import subprocess
import time import time
import warnings import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -13,6 +15,7 @@ from pathlib import Path
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
import numpy as np import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface from leann.interface import LeannBackendSearcherInterface
@@ -36,6 +39,7 @@ def compute_embeddings(
use_server: bool = True, use_server: bool = True,
port: Optional[int] = None, port: Optional[int] = None,
is_build=False, is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Computes embeddings using different backends. Computes embeddings using different backends.
@@ -69,6 +73,7 @@ def compute_embeddings(
model_name, model_name,
mode=mode, mode=mode,
is_build=is_build, is_build=is_build,
provider_options=provider_options,
) )
@@ -275,6 +280,7 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever", embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None, dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs, **backend_kwargs,
): ):
self.backend_name = backend_name self.backend_name = backend_name
@@ -297,6 +303,7 @@ class LeannBuilder:
self.embedding_model = embedding_model self.embedding_model = embedding_model
self.dimensions = dimensions self.dimensions = dimensions
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings # Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = { normalized_embeddings_models = {
@@ -404,6 +411,7 @@ class LeannBuilder:
self.embedding_model, self.embedding_model,
self.embedding_mode, self.embedding_mode,
use_server=False, use_server=False,
provider_options=self.embedding_options,
)[0] )[0]
) )
path = Path(index_path) path = Path(index_path)
@@ -443,6 +451,7 @@ class LeannBuilder:
self.embedding_mode, self.embedding_mode,
use_server=False, use_server=False,
is_build=True, is_build=True,
provider_options=self.embedding_options,
) )
string_ids = [chunk["id"] for chunk in self.chunks] string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions} current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -469,14 +478,15 @@ class LeannBuilder:
], ],
} }
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend # Add storage status flags for HNSW backend
if self.backend_name == "hnsw": if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True) is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True) is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = ( meta_data["is_pruned"] = bool(is_recompute)
is_compact and is_recompute
) # Pruned only if compact and recompute
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
@@ -591,18 +601,166 @@ class LeannBuilder:
"embeddings_source": str(embeddings_file), "embeddings_source": str(embeddings_file),
} }
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend # Add storage status flags for HNSW backend
if self.backend_name == "hnsw": if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True) is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True) is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f: with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2) json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}") logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
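For reference, a sketch of how the new update flow could be driven from caller code; the paths, model, and the add_text helper are illustrative assumptions, and the target index must have been built with is_compact=False:

# Illustrative sketch only.
from leann.api import LeannBuilder

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",  # dimensions are checked against the existing index
)
builder.add_text("A brand-new passage to append to the existing index.")
# Appends passages and vectors, extends the offset map, bumps total_passages,
# and re-prunes stored embeddings when the index is marked for recompute.
builder.update_index("./indexes/demo.leann")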
class LeannSearcher: class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs): def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -626,6 +784,7 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"] self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format # Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers") self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager # Delegate portability handling to PassageManager
self.passage_manager = PassageManager( self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
@@ -637,6 +796,8 @@ class LeannSearcher:
raise ValueError(f"Backend '{backend_name}' not found.") raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs} final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher( self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs index_path, **final_kwargs
) )
@@ -653,6 +814,7 @@ class LeannSearcher:
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0, batch_size: int = 0,
use_grep: bool = False,
**kwargs, **kwargs,
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
@@ -679,6 +841,10 @@ class LeannSearcher:
Returns: Returns:
List of SearchResult objects with text, metadata, and similarity scores List of SearchResult objects with text, metadata, and similarity scores
""" """
# Handle grep search
if use_grep:
return self._grep_search(query, top_k)
logger.info("🔍 LeannSearcher.search() called:") logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'") logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}") logger.info(f" Top_k: {top_k}")
@@ -795,9 +961,96 @@ class LeannSearcher:
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}") logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results return enriched_results
def _find_jsonl_file(self) -> Optional[str]:
"""Find the .jsonl file containing raw passages for grep search"""
index_path = Path(self.meta_path_str).parent
potential_files = [
index_path / "documents.leann.passages.jsonl",
index_path.parent / "documents.leann.passages.jsonl",
]
for file_path in potential_files:
if file_path.exists():
return str(file_path)
return None
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform grep-based search on raw passages"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl passages file found for grep search")
try:
cmd = ["grep", "-i", "-n", query, jsonl_file]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 1:
return []
elif result.returncode != 0:
raise RuntimeError(f"Grep failed: {result.stderr}")
matches = []
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split(":", 1)
if len(parts) != 2:
continue
try:
data = json.loads(parts[1])
text = data.get("text", "")
score = text.lower().count(query.lower())
matches.append(
SearchResult(
id=data.get("id", parts[0]),
text=text,
metadata=data.get("metadata", {}),
score=float(score),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
except FileNotFoundError:
raise RuntimeError(
"grep command not found. Please install grep or use semantic search."
)
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Fallback regex search"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl file found")
pattern = re.compile(re.escape(query), re.IGNORECASE)
matches = []
with open(jsonl_file, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if pattern.search(line):
try:
data = json.loads(line.strip())
matches.append(
SearchResult(
id=data.get("id", str(line_num)),
text=data.get("text", ""),
metadata=data.get("metadata", {}),
score=float(len(pattern.findall(data.get("text", "")))),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
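In practice the new keyword is a drop-in switch on the existing search call; a short sketch with a placeholder index path (import path assumed from the relative imports above):

# Illustrative sketch only.
from leann.api import LeannSearcher

searcher = LeannSearcher("./indexes/demo.leann")
# use_grep=True skips the embedding server and runs a case-insensitive grep over
# the raw .jsonl passages, ranking hits by how often the query string appears.
results = searcher.search("ConnectionError", top_k=5, use_grep=True)
for result in results:
    print(result.score, result.text[:80])
searcher.cleanup()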
def cleanup(self): def cleanup(self):
"""Explicitly cleanup embedding server resources. """Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher, This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios. especially in test environments or batch processing scenarios.
""" """
@@ -853,6 +1106,7 @@ class LeannChat:
expected_zmq_port: int = 5557, expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0, batch_size: int = 0,
use_grep: bool = False,
**search_kwargs, **search_kwargs,
): ):
if llm_kwargs is None: if llm_kwargs is None:

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest( def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434" model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
"""Validate model name and provide suggestions if invalid""" """Validate model name and provide suggestions if invalid"""
if llm_type == "ollama": if llm_type == "ollama":
available_models = check_ollama_models(host) resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models: if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation." error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface): class OllamaChat(LLMInterface):
"""LLM interface for Ollama models.""" """LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"): def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model self.model = model
self.host = host self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'") logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try: try:
import requests import requests
# Check if the Ollama server is responsive # Check if the Ollama server is responsive
if host: if self.host:
requests.get(host) requests.get(self.host)
# Pre-check model availability with helpful suggestions # Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host) model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error: if model_error:
raise ValueError(model_error) raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'." "The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
) )
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.") logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError( raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running." f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
) )
def ask(self, prompt: str, **kwargs) -> str: def ask(self, prompt: str, **kwargs) -> str:
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface): class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models.""" """LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None): def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key: if not self.api_key:
raise ValueError( raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
) )
logger.info(f"Initializing OpenAI Chat with model='{model}'") logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try: try:
import openai import openai
self.client = openai.OpenAI(api_key=self.api_key) self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'." "The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama": if llm_type == "ollama":
return OllamaChat( return OllamaChat(
model=model or "llama3:8b", model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"), host=llm_config.get("host"),
) )
elif llm_type == "hf": elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat") return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai": elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key")) return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini": elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key")) return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated": elif llm_type == "simulated":

View File

@@ -1,6 +1,6 @@
""" """
Enhanced chunking utilities with AST-aware code chunking support. Enhanced chunking utilities with AST-aware code chunking support.
Provides unified interface for both traditional and AST-based text chunking. Packaged within leann-core so installed wheels can import it reliably.
""" """
import logging import logging
@@ -22,30 +22,9 @@ CODE_EXTENSIONS = {
".jsx": "typescript", ".jsx": "typescript",
} }
# Default chunk parameters for different content types
DEFAULT_CHUNK_PARAMS = {
"code": {
"max_chunk_size": 512,
"chunk_overlap": 64,
},
"text": {
"chunk_size": 256,
"chunk_overlap": 128,
},
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]: def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
""" """Separate documents into code files and regular text files."""
Separate documents into code files and regular text files.
Args:
documents: List of LlamaIndex Document objects
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
Returns:
Tuple of (code_documents, text_documents)
"""
if code_extensions is None: if code_extensions is None:
code_extensions = CODE_EXTENSIONS code_extensions = CODE_EXTENSIONS
@@ -53,16 +32,10 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
text_docs = [] text_docs = []
for doc in documents: for doc in documents:
# Get file path from metadata file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
file_path = doc.metadata.get("file_path", "")
if not file_path:
# Fallback to file_name
file_path = doc.metadata.get("file_name", "")
if file_path: if file_path:
file_ext = Path(file_path).suffix.lower() file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions: if file_ext in code_extensions:
# Add language info to metadata
doc.metadata["language"] = code_extensions[file_ext] doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True doc.metadata["is_code"] = True
code_docs.append(doc) code_docs.append(doc)
@@ -70,7 +43,6 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
doc.metadata["is_code"] = False doc.metadata["is_code"] = False
text_docs.append(doc) text_docs.append(doc)
else: else:
# If no file path, treat as text
doc.metadata["is_code"] = False doc.metadata["is_code"] = False
text_docs.append(doc) text_docs.append(doc)
@@ -79,7 +51,7 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
def get_language_from_extension(file_path: str) -> Optional[str]: def get_language_from_extension(file_path: str) -> Optional[str]:
"""Get the programming language from file extension.""" """Return language string from a filename/extension using CODE_EXTENSIONS."""
ext = Path(file_path).suffix.lower() ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext) return CODE_EXTENSIONS.get(ext)
@@ -90,40 +62,26 @@ def create_ast_chunks(
chunk_overlap: int = 64, chunk_overlap: int = 64,
metadata_template: str = "default", metadata_template: str = "default",
) -> list[str]: ) -> list[str]:
""" """Create AST-aware chunks from code documents using astchunk.
Create AST-aware chunks from code documents using astchunk.
Args: Falls back to traditional chunking if astchunk is unavailable.
documents: List of code documents
max_chunk_size: Maximum characters per chunk
chunk_overlap: Number of AST nodes to overlap between chunks
metadata_template: Template for chunk metadata
Returns:
List of text chunks with preserved code structure
""" """
try: try:
from astchunk import ASTChunkBuilder from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e: except ImportError as e:
logger.error(f"astchunk not available: {e}") logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files") logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap) return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = [] all_chunks = []
for doc in documents: for doc in documents:
# Get language from metadata (set by detect_code_files)
language = doc.metadata.get("language") language = doc.metadata.get("language")
if not language: if not language:
logger.warning( logger.warning("No language detected; falling back to traditional chunking")
"No language detected for document, falling back to traditional chunking" all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
)
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
continue continue
try: try:
# Configure astchunk
configs = { configs = {
"max_chunk_size": max_chunk_size, "max_chunk_size": max_chunk_size,
"language": language, "language": language,
@@ -131,7 +89,6 @@ def create_ast_chunks(
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0, "chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
} }
# Add repository-level metadata if available
repo_metadata = { repo_metadata = {
"file_path": doc.metadata.get("file_path", ""), "file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""), "file_name": doc.metadata.get("file_name", ""),
@@ -140,17 +97,13 @@ def create_ast_chunks(
} }
configs["repo_level_metadata"] = repo_metadata configs["repo_level_metadata"] = repo_metadata
# Create chunk builder and process
chunk_builder = ASTChunkBuilder(**configs) chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content() code_content = doc.get_content()
if not code_content or not code_content.strip(): if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping") logger.warning("Empty code content, skipping")
continue continue
chunks = chunk_builder.chunkify(code_content) chunks = chunk_builder.chunkify(code_content)
# Extract text content from chunks
for chunk in chunks: for chunk in chunks:
if hasattr(chunk, "text"): if hasattr(chunk, "text"):
chunk_text = chunk.text chunk_text = chunk.text
@@ -159,7 +112,6 @@ def create_ast_chunks(
elif isinstance(chunk, str): elif isinstance(chunk, str):
chunk_text = chunk chunk_text = chunk
else: else:
# Try to convert to string
chunk_text = str(chunk) chunk_text = str(chunk)
if chunk_text and chunk_text.strip(): if chunk_text and chunk_text.strip():
@@ -168,12 +120,10 @@ def create_ast_chunks(
logger.info( logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}" f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
) )
except Exception as e: except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}") logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking") logger.info("Falling back to traditional chunking")
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap) all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
all_chunks.extend(traditional_chunks)
return all_chunks return all_chunks
@@ -181,23 +131,10 @@ def create_ast_chunks(
def create_traditional_chunks( def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128 documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]: ) -> list[str]:
""" """Create traditional text chunks using LlamaIndex SentenceSplitter."""
Create traditional text chunks using LlamaIndex SentenceSplitter.
Args:
documents: List of documents to chunk
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
Returns:
List of text chunks
"""
# Handle invalid chunk_size values
if chunk_size <= 0: if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256") logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256 chunk_size = 256
# Ensure chunk_overlap is not negative and not larger than chunk_size
if chunk_overlap < 0: if chunk_overlap < 0:
chunk_overlap = 0 chunk_overlap = 0
if chunk_overlap >= chunk_size: if chunk_overlap >= chunk_size:
@@ -215,12 +152,9 @@ def create_traditional_chunks(
try: try:
nodes = node_parser.get_nodes_from_documents([doc]) nodes = node_parser.get_nodes_from_documents([doc])
if nodes: if nodes:
chunk_texts = [node.get_content() for node in nodes] all_texts.extend(node.get_content() for node in nodes)
all_texts.extend(chunk_texts)
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
except Exception as e: except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}") logger.error(f"Traditional chunking failed for document: {e}")
# As last resort, add the raw content
content = doc.get_content() content = doc.get_content()
if content and content.strip(): if content and content.strip():
all_texts.append(content.strip()) all_texts.append(content.strip())
@@ -238,32 +172,13 @@ def create_text_chunks(
code_file_extensions: Optional[list[str]] = None, code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True, ast_fallback_traditional: bool = True,
) -> list[str]: ) -> list[str]:
""" """Create text chunks from documents with optional AST support for code files."""
Create text chunks from documents with optional AST support for code files.
Args:
documents: List of LlamaIndex Document objects
chunk_size: Size for traditional text chunks
chunk_overlap: Overlap for traditional text chunks
use_ast_chunking: Whether to use AST chunking for code files
ast_chunk_size: Size for AST chunks
ast_chunk_overlap: Overlap for AST chunks
code_file_extensions: Custom list of code file extensions
ast_fallback_traditional: Fall back to traditional chunking on AST errors
Returns:
List of text chunks
"""
if not documents: if not documents:
logger.warning("No documents provided for chunking") logger.warning("No documents provided for chunking")
return [] return []
# Create a local copy of supported extensions for this function call
local_code_extensions = CODE_EXTENSIONS.copy() local_code_extensions = CODE_EXTENSIONS.copy()
# Update supported extensions if provided
if code_file_extensions: if code_file_extensions:
# Map extensions to languages (simplified mapping)
ext_mapping = { ext_mapping = {
".py": "python", ".py": "python",
".java": "java", ".java": "java",
@@ -273,47 +188,32 @@ def create_text_chunks(
} }
for ext in code_file_extensions: for ext in code_file_extensions:
if ext.lower() not in local_code_extensions: if ext.lower() not in local_code_extensions:
# Try to guess language from extension
if ext.lower() in ext_mapping: if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()] local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else: else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking") logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = [] all_chunks = []
if use_ast_chunking: if use_ast_chunking:
# Separate code and text documents using local extensions
code_docs, text_docs = detect_code_files(documents, local_code_extensions) code_docs, text_docs = detect_code_files(documents, local_code_extensions)
# Process code files with AST chunking
if code_docs: if code_docs:
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
try: try:
ast_chunks = create_ast_chunks( all_chunks.extend(
create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
) )
all_chunks.extend(ast_chunks) )
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
except Exception as e: except Exception as e:
logger.error(f"AST chunking failed: {e}") logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional: if ast_fallback_traditional:
logger.info("Falling back to traditional chunking for code files") all_chunks.extend(
traditional_code_chunks = create_traditional_chunks( create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
code_docs, chunk_size, chunk_overlap
) )
all_chunks.extend(traditional_code_chunks)
else: else:
raise raise
# Process text files with traditional chunking
if text_docs: if text_docs:
logger.info(f"Processing {len(text_docs)} text files with traditional chunking") all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
all_chunks.extend(text_chunks)
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
else: else:
# Use traditional chunking for all files
logger.info(f"Processing {len(documents)} documents with traditional chunking")
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap) all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}") logger.info(f"Total chunks created: {len(all_chunks)}")

View File

@@ -1,6 +1,5 @@
import argparse import argparse
import asyncio import asyncio
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Union from typing import Any, Optional, Union
@@ -10,6 +9,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import register_project_directory from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str: def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -124,6 +124,24 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"], choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)", help="Embedding backend mode (default: sentence-transformers)",
) )
build_parser.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
build_parser.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
build_parser.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument( build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index" "--force", "-f", action="store_true", help="Force rebuild existing index"
) )
@@ -239,6 +257,11 @@ Examples:
# Ask command # Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions") ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name") ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--llm", "--llm",
type=str, type=str,
@@ -249,7 +272,12 @@ Examples:
ask_parser.add_argument( ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)" "--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
) )
ask_parser.add_argument("--host", type=str, default="http://localhost:11434") ask_parser.add_argument(
"--host",
type=str,
default=None,
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
ask_parser.add_argument( ask_parser.add_argument(
"--interactive", "-i", action="store_true", help="Interactive chat mode" "--interactive", "-i", action="store_true", help="Interactive chat mode"
) )
@@ -278,6 +306,18 @@ Examples:
default=None, default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.", help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
) )
ask_parser.add_argument(
"--api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
)
ask_parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command # List command
subparsers.add_parser("list", help="List all indexes") subparsers.add_parser("list", help="List all indexes")
@@ -1216,13 +1256,8 @@ Examples:
if use_ast: if use_ast:
print("🧠 Using AST-aware chunking for code files") print("🧠 Using AST-aware chunking for code files")
try: try:
# Import enhanced chunking utilities # Import enhanced chunking utilities from packaged module
# Add apps directory to path to import chunking utilities from .chunking_utils import create_text_chunks
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
if apps_dir.exists():
sys.path.insert(0, str(apps_dir))
from chunking import create_text_chunks
# Use enhanced chunking with AST support # Use enhanced chunking with AST support
all_texts = create_text_chunks( all_texts = create_text_chunks(
@@ -1237,7 +1272,9 @@ Examples:
) )
except ImportError as e: except ImportError as e:
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking") print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
)
use_ast = False use_ast = False
if not use_ast: if not use_ast:
@@ -1329,10 +1366,20 @@ Examples:
print(f"Building index '{index_name}' with {args.backend} backend...") print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder( builder = LeannBuilder(
backend_name=args.backend, backend_name=args.backend,
embedding_model=args.embedding_model, embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode, embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree, graph_degree=args.graph_degree,
complexity=args.complexity, complexity=args.complexity,
is_compact=args.compact, is_compact=args.compact,
@@ -1480,11 +1527,38 @@ Examples:
llm_config = {"type": args.llm, "model": args.model} llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama": if args.llm == "ollama":
llm_config["host"] = args.host llm_config["host"] = resolve_ollama_host(args.host)
elif args.llm == "openai":
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config) chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
response = chat.ask(
prompt,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
initial_query = (args.query or "").strip()
if args.interactive: if args.interactive:
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit") print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40) print("=" * 40)
@@ -1497,41 +1571,14 @@ Examples:
if not user_input: if not user_input:
continue continue
# Prepare LLM kwargs with thinking budget if specified _ask_once(user_input)
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
else: else:
query = input("Enter your question: ").strip() query = initial_query or input("Enter your question: ").strip()
if query: if not query:
# Prepare LLM kwargs with thinking budget if specified print("No question provided. Exiting.")
llm_kwargs = {} return
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask( _ask_once(query)
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
async def run(self, args=None): async def run(self, args=None):
parser = self.create_parser() parser = self.create_parser()
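The positional question plus the new --api-base/--api-key flags make one-shot queries against OpenAI-compatible local servers possible; roughly equivalent Python, with placeholder index name, URL, and question:

# Illustrative sketch only; mirrors:
#   leann ask my-docs "How are passages stored?" --llm openai --api-base http://localhost:10000/v1
from leann.api import LeannChat

chat = LeannChat(
    index_path="./indexes/my-docs.leann",
    llm_config={
        "type": "openai",
        "model": "gpt-4o",
        "base_url": "http://localhost:10000/v1",  # any OpenAI-compatible endpoint
        # api_key omitted: resolve_openai_api_key falls back to OPENAI_API_KEY
    },
)
print(f"LEANN: {chat.ask('How are passages stored?', top_k=20)}")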

View File

@@ -7,11 +7,13 @@ Preserves all optimization parameters to ensure performance
import logging import logging
import os import os
import time import time
from typing import Any from typing import Any, Optional
import numpy as np import numpy as np
import torch import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level # Set up logger with proper level
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper() LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -31,6 +33,7 @@ def compute_embeddings(
adaptive_optimization: bool = True, adaptive_optimization: bool = True,
manual_tokenize: bool = False, manual_tokenize: bool = False,
max_length: int = 512, max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Unified embedding computation entry point Unified embedding computation entry point
@@ -46,6 +49,8 @@ def compute_embeddings(
Returns: Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
""" """
provider_options = provider_options or {}
if mode == "sentence-transformers": if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers( return compute_embeddings_sentence_transformers(
texts, texts,
@@ -57,11 +62,21 @@ def compute_embeddings(
max_length=max_length, max_length=max_length,
) )
elif mode == "openai": elif mode == "openai":
return compute_embeddings_openai(texts, model_name) return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
elif mode == "mlx": elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name) return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama": elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build) return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini": elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build) return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else: else:
@@ -353,12 +368,15 @@ def compute_embeddings_sentence_transformers(
return embeddings return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray: def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode # TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API""" """Compute embeddings using OpenAI API"""
try: try:
import os
import openai import openai
except ImportError as e: except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}") raise ImportError(f"OpenAI package not installed: {e}")
@@ -373,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI." f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
) )
api_key = os.getenv("OPENAI_API_KEY") resolved_base_url = resolve_openai_base_url(base_url)
if not api_key: resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set") raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client # Cache OpenAI client
cache_key = "openai_client" cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache: if cache_key in _model_cache:
client = _model_cache[cache_key] client = _model_cache[cache_key]
else: else:
client = openai.OpenAI(api_key=api_key) client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client _model_cache[cache_key] = client
logger.info("OpenAI client cached") logger.info("OpenAI client cached")
@@ -507,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama( def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434" texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
) -> np.ndarray: ) -> np.ndarray:
""" """
Compute embeddings using Ollama API with simplified batch processing. Compute embeddings using Ollama API with simplified batch processing.
@@ -518,7 +541,7 @@ def compute_embeddings_ollama(
texts: List of texts to compute embeddings for texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large") model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar) is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434) host: Ollama host URL (defaults to environment or http://localhost:11434)
Returns: Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim) Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -533,17 +556,19 @@ def compute_embeddings_ollama(
if not texts: if not texts:
raise ValueError("Cannot compute embeddings for empty text list") raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info( logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'" f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
) )
# Check if Ollama is running # Check if Ollama is running
try: try:
response = requests.get(f"{host}/api/version", timeout=5) response = requests.get(f"{resolved_host}/api/version", timeout=5)
response.raise_for_status() response.raise_for_status()
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
error_msg = ( error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n" f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
"Please ensure Ollama is running:\n" "Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n" " • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n" " • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -555,7 +580,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions # Check if model exists and provide helpful suggestions
try: try:
response = requests.get(f"{host}/api/tags", timeout=5) response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response.raise_for_status() response.raise_for_status()
models = response.json() models = response.json()
model_names = [model["name"] for model in models.get("models", [])] model_names = [model["name"] for model in models.get("models", [])]
@@ -618,7 +643,9 @@ def compute_embeddings_ollama(
# Verify the model supports embeddings by testing it # Verify the model supports embeddings by testing it
try: try:
test_response = requests.post( test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10 f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
) )
if test_response.status_code != 200: if test_response.status_code != 200:
error_msg = ( error_msg = (
@@ -665,7 +692,7 @@ def compute_embeddings_ollama(
while retry_count < max_retries: while retry_count < max_retries:
try: try:
response = requests.post( response = requests.post(
f"{host}/api/embeddings", f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text}, json={"model": model_name, "prompt": truncated_text},
timeout=30, timeout=30,
) )
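The provider options reach compute_embeddings unchanged; a sketch of a direct call (module path, host, and texts are assumptions, and LeannBuilder or the searcher normally passes these through automatically):

# Illustrative sketch only.
from leann.embedding_compute import compute_embeddings

vectors = compute_embeddings(
    ["first passage", "second passage"],
    "nomic-embed-text",
    mode="ollama",
    is_build=True,                                      # enables the build-time progress path
    provider_options={"host": "http://gpu-box:11434"},  # forwarded to compute_embeddings_ollama
)
print(vectors.shape)  # (2, embedding_dim)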

View File

@@ -8,6 +8,8 @@ import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection # Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable # Set up logging based on environment variable
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start the embedding server.""" """Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here # passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
config_signature = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
# If this manager already has a live server, just reuse it # If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port: if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server") logger.info("Reusing in-process server")
return True, self.server_port return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# For Colab environment, use a different strategy # For Colab environment, use a different strategy
if _is_colab_environment(): if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy") logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs) return self._start_server_colab(
port,
model_name,
embedding_mode,
provider_options=provider_options,
**kwargs,
)
# Always pick a fresh available port # Always pick a fresh available port
try: try:
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
return False, port return False, port
# Start a new server # Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs) return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
)
def _start_server_colab( def _start_server_colab(
self, self,
port: int, port: int,
model_name: str, model_name: str,
embedding_mode: str = "sentence-transformers", embedding_mode: str = "sentence-transformers",
provider_options: Optional[dict] = None,
**kwargs, **kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start server with Colab-specific configuration.""" """Start server with Colab-specific configuration."""
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
try: try:
# In Colab, we'll use a more direct approach # In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port) self._launch_server_process_colab(
return self._wait_for_server_ready_colab(actual_port) command,
actual_port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e: except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}") logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port return False, actual_port
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance # Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server( def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]: ) -> tuple[bool, int]:
"""Start a new embedding server on the given port.""" """Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...") logger.info(f"Starting embedding server on port {port}...")
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs) command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try: try:
self._launch_server_process(command, port) self._launch_server_process(
return self._wait_for_server_ready(port) command,
port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e: except Exception as e:
logger.error(f"Failed to start embedding server: {e}") logger.error(f"Failed to start embedding server: {e}")
return False, port return False, port
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
return command return command
def _launch_server_process(self, command: list, port: int) -> None: def _launch_server_process(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process.""" """Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}") logger.info(f"Command: {' '.join(command)}")
@@ -193,14 +262,20 @@ class EmbeddingServerManager:
# Start embedding server subprocess # Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}") logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
cwd=project_root, cwd=project_root,
stdout=stdout_target, stdout=stdout_target,
stderr=stderr_target, stderr=stderr_target,
env=env,
) )
self.server_port = port self.server_port = port
# Record config for in-process reuse # Record config for in-process reuse (best effort; refined later when ready)
try: try:
self._server_config = { self._server_config = {
"model_name": command[command.index("--model-name") + 1] "model_name": command[command.index("--model-name") + 1]
@@ -212,12 +287,14 @@ class EmbeddingServerManager:
"embedding_mode": command[command.index("--embedding-mode") + 1] "embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command if "--embedding-mode" in command
else "sentence-transformers", else "sentence-transformers",
"provider_options": provider_options or {},
} }
except Exception: except Exception:
self._server_config = { self._server_config = {
"model_name": "", "model_name": "",
"passages_file": "", "passages_file": "",
"embedding_mode": "sentence-transformers", "embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
} }
logger.info(f"Server process started with PID: {self.server_process.pid}") logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -322,16 +399,27 @@ class EmbeddingServerManager:
# Removed: cross-process adoption no longer supported # Removed: cross-process adoption no longer supported
return return
def _launch_server_process_colab(self, command: list, port: int) -> None: def _launch_server_process_colab(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings.""" """Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}") logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management # In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen( self.server_process = subprocess.Popen(
command, command,
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
env=env,
) )
self.server_port = port self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}") logger.info(f"Colab server process started with PID: {self.server_process.pid}")
@@ -345,6 +433,7 @@ class EmbeddingServerManager:
"model_name": "", "model_name": "",
"passages_file": "", "passages_file": "",
"embedding_mode": "sentence-transformers", "embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
} }
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]: def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
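The manager hands provider options to the child process through LEANN_EMBEDDING_OPTIONS; a sketch of the round trip, where the reading side is an assumption about how the spawned server consumes the variable:

# Illustrative sketch only.
import json
import os

from leann.settings import encode_provider_options

# Parent side, as in _launch_server_process above:
env = os.environ.copy()
payload = encode_provider_options({"host": "http://gpu-box:11434"})
if payload:
    env["LEANN_EMBEDDING_OPTIONS"] = payload

# Child side (assumed): decode the options before building the embedding backend.
options = json.loads(env.get("LEANN_EMBEDDING_OPTIONS", "{}"))
print(options.get("host"))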

View File

@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
        self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
+       self.embedding_options = self.meta.get("embedding_options", {})
        self.embedding_server_manager = EmbeddingServerManager(
            backend_module_name=backend_module_name,
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
            passages_file=passages_source_file,
            distance_metric=distance_metric,
            enable_warmup=kwargs.get("enable_warmup", False),
+           provider_options=self.embedding_options,
        )
        if not server_started:
            raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
        from .embedding_compute import compute_embeddings
        embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
-       return compute_embeddings([query], self.embedding_model, embedding_mode)
+       return compute_embeddings(
+           [query],
+           self.embedding_model,
+           embedding_mode,
+           provider_options=self.embedding_options,
+       )

    def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
        """Compute embeddings using the ZMQ embedding server."""

View File

@@ -0,0 +1,74 @@
"""Runtime configuration helpers for LEANN."""
from __future__ import annotations
import json
import os
from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
def _clean_url(value: str) -> str:
"""Normalize URL strings by stripping trailing slashes."""
return value.rstrip("/") if value else value
def resolve_ollama_host(explicit: str | None = None) -> str:
"""Resolve the Ollama-compatible endpoint to use."""
candidates = (
explicit,
os.getenv("LEANN_LOCAL_LLM_HOST"),
os.getenv("LEANN_OLLAMA_HOST"),
os.getenv("OLLAMA_HOST"),
os.getenv("LOCAL_LLM_ENDPOINT"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OLLAMA_HOST)
def resolve_openai_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for OpenAI-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_OPENAI_BASE_URL"),
os.getenv("OPENAI_BASE_URL"),
os.getenv("LOCAL_OPENAI_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
if explicit:
return explicit
return os.getenv("OPENAI_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""
if not options:
return None
try:
return json.dumps(options)
except (TypeError, ValueError):
# Fall back to empty payload if serialization fails
return None
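A quick usage sketch of the resolution order these helpers implement: an explicit argument wins, then the environment variables in the order listed above, then the built-in default, with trailing slashes stripped in every case. The `leann.settings` import path is an assumption, since the diff does not show the new module's location:

```python
import os

from leann.settings import resolve_ollama_host, resolve_openai_base_url  # module path assumed

# Make the example deterministic: clear competing overrides first.
for var in ("LEANN_LOCAL_LLM_HOST", "LEANN_OLLAMA_HOST",
            "LEANN_OPENAI_BASE_URL", "OPENAI_BASE_URL", "LOCAL_OPENAI_BASE_URL"):
    os.environ.pop(var, None)
os.environ["OLLAMA_HOST"] = "http://192.168.1.10:11434/"

print(resolve_ollama_host("http://gpu-box:11434"))  # explicit argument wins -> http://gpu-box:11434
print(resolve_ollama_host())                        # env var used, trailing slash stripped -> http://192.168.1.10:11434
print(resolve_openai_base_url())                    # nothing set -> https://api.openai.com/v1
```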

View File

@@ -2,6 +2,8 @@
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.

+For agent-facing discovery details, see `llms.txt` in the repository root.

## Prerequisites

Install LEANN globally for MCP integration (with default backend):

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
-version = "0.3.3"
+version = "0.3.4"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -99,6 +99,7 @@ wechat-exporter = "wechat_exporter.main:main"
leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
+astchunk = { path = "packages/astchunk-leann", editable = true }

[tool.ruff]
target-version = "py39"

14
tests/test_cli_ask.py Normal file
View File

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI


def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
    monkeypatch.chdir(tmp_path)
    cli = LeannCLI()
    parser = cli.create_parser()

    args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])

    assert args.command == "ask"
    assert args.index_name == "my-docs"
    assert args.query == "Where are prompts configured?"
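In CLI terms, the positional form exercised here corresponds to `leann ask my-docs "Where are prompts configured?"` rather than passing the question through a flag or an interactive prompt; the test only checks argument parsing, so it can be run on its own with `pytest tests/test_cli_ask.py`.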

7401
uv.lock generated
View File

File diff suppressed because it is too large.