Compare commits

..

21 Commits

Author SHA1 Message Date
Andy Lee
47aeb85f82 Allow 'leann ask' to accept a positional question 2025-09-23 15:18:51 -07:00
Andy Lee
db7ba27ff6 feat: Add support for configurable local LLM endpoints (#115)
* feat: support configurable local llm endpoints

* docs
2025-09-23 15:12:13 -07:00
Andy Lee
5f7806e16f Introducing dynamic index update (#108)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly

* fix: enable is_compact=False, is_recompute=True

* feat: update when recompute

* test

* fix: real recompute

* refactor

* fix: compare with no-recompute

* fix: test
2025-09-21 22:56:27 -07:00
yichuan-w
d034e2195b fix build from source in diskann 2025-09-20 19:52:29 +00:00
yichuan520030910320
43894ff605 update submodule 2025-09-19 17:03:55 -07:00
yichuan520030910320
10311cc611 change the submodule for easy pull 2025-09-19 17:02:09 -07:00
Andy Lee
ad0d2faabc feat: Add GitHub PR and issue templates (#105)
* feat: Add GitHub PR and issue templates for better contributor experience

* simplify: Make templates more concise and user-friendly
2025-09-19 13:51:36 -07:00
Andy Lee
e93c0dec6f [Fix] Enable AST chunking when installed (package chunking utils) (#101)
* fix(core): package chunking utils for AST chunking; re-export in apps; CLI imports packaged utils

* style

* chore: fix ruff warnings (RUF059, F401)

* style
2025-09-17 18:44:00 -07:00
GitHub Actions
c5a29f849a chore: release v0.3.4 2025-09-16 20:45:22 +00:00
Yichuan Wang
3b8dc6368e Ast fork (#92) 2025-09-08 18:43:31 -07:00
Aiden Huang
e309f292de docs(mcp): add root llms.txt for MCP discovery; update MCP README to reference it; refs #76 (#91) 2025-09-07 14:39:58 -07:00
AWS Mcleod
0d9f92ea0f Add grep search functionality - Issue #86 (#87)
* Add grep search functionality to LeannSearcher

- Add use_grep parameter to search method
- Implement grep-based search on .jsonl files
- Add fallback Python regex search
- Support same SearchResult format as semantic search

Addresses issue #86

* fix: resolve linting errors

* docs: add grep search example

* docs: add grep search to README examples

* refactor: remove regex fallback, move grep example to features section

* docs: add grep search to Advanced Features with comprehensive guide
2025-09-05 13:48:07 -07:00
GitHub Actions
b0b353d279 chore: release v0.3.3 2025-09-02 21:29:56 +00:00
Andy Lee
4dffdfedbe feat: Add ARM64 Linux wheel support for leann-backend-hnsw (#83)
* feat: Add ARM64 Linux wheel support for leann-backend-hnsw

* fix: Use OpenBLAS for ARM64 Linux builds instead of Intel MKL

* fix: Configure Faiss with SVE optimization for ARM64 builds

- Set FAISS_OPT_LEVEL to "sve" for ARM64 architecture
- Disable x86-specific SIMD instructions (AVX2, AVX512, SSE4.1)
- Use ARM64-native SVE optimization as per Faiss conda build scripts
- Add architecture detection and proper configuration messages

Fixes compilation error: "xmmintrin.h: No such file or directory"
on ubuntu-24.04-arm runners.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Apply ARM64 compatibility fix directly to Faiss submodule

- Modify faiss/impl/pq.cpp to use x86-specific preprocessor conditions
- Remove patch file approach in favor of direct submodule modification
- Update CMakeLists.txt to reflect the submodule changes
- Fixes ARM64 Linux compilation by preventing x86 SIMD header inclusion

This resolves the "xmmintrin.h: No such file or directory" error
when building ARM64 Linux wheels for Docker compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: Update Faiss submodule to include ARM64 compatibility fix

- Points to commit ed96ff7d with x86-specific preprocessor conditions
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Use different optimization levels for ARM64 based on platform

- Use SVE optimization only for ARM64 Linux
- Use generic optimization for ARM64 macOS to avoid clang SVE issues
- Fixes macOS ARM64 compilation errors with SVE instructions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: Update DiskANN submodule with OpenBLAS fallback support

- Points to commit 5c396c4 with ARM64 Linux OpenBLAS support
- Enables DiskANN to build on ARM64 Linux using standard BLAS libraries
- Resolves Intel MKL dependency issues for Docker ARM64 deployments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with ZeroMQ polling configuration

- Points to commit 3a1016e with explicit polling method setup
- Resolves ZeroMQ autodetection issues on ARM64 Linux
- Ensures stable cross-platform ZeroMQ builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Update DiskANN submodule with ARM64 compiler flags fix

- Points to commit a0dc600 with architecture-specific compiler flags
- Removes x86 SIMD flags on ARM64 Linux to fix compilation errors
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with ARM64 compiler flags fix

- Points to commit 0921664 with architecture-specific compiler flags
- Removes x86 SIMD flags on ARM64 Linux to fix compilation errors
- Enables successful ARM64 Linux wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* retrigger ci

* fix: Update DiskANN submodule with cross-platform prefetch support

- Points to commit 39192d6 with unified prefetch macros
- Replaces all Intel-specific _mm_prefetch calls with cross-platform macros
- Enables ARM64 Linux compatibility while maintaining x86 performance
- Resolves all remaining compilation errors for ARM64 builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with corrected ARM64 compatibility fixes

- Points to commit 3cb87a8 with proper x86 platform detection
- Includes ARM64 fallback for AVXDistanceInnerProductFloat function
- Resolves all remaining '__m256 was not declared' compilation errors
- Enables successful ARM64 Linux wheel builds for Docker compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with template type handling fix

- Points to commit d396bc3 with corrected template type handling
- Fixes DistanceInnerProduct template instantiation for int8_t/uint8_t types
- Resolves 'cannot convert const signed char* to const float*' error
- Completes ARM64 Linux compilation compatibility

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with DistanceFastL2::norm template fix

- Points to commit 69d9a99 with corrected template type handling
- Fixes DistanceFastL2::norm template instantiation for int8_t/uint8_t types
- Resolves another 'cannot convert const signed char* to const float*' error
- Continues ARM64 Linux compilation compatibility improvements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with LAPACKE header detection

- Points to commit 64a9e01 with LAPACKE header path configuration
- Adds pkg-config based detection for LAPACKE include directories
- Resolves 'lapacke.h: No such file or directory' compilation error
- Completes OpenBLAS integration for ARM64 Linux builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with enhanced LAPACKE header detection

- Points to commit 18d0721 with fallback LAPACKE header search paths
- Checks multiple standard locations for lapacke.h on various systems
- Improves ARM64 Linux compatibility for OpenBLAS builds
- Should resolve 'lapacke.h: No such file or directory' errors

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Add liblapacke-dev package for ARM64 Linux builds

- Add liblapacke-dev to ARM64 dependencies alongside libopenblas-dev
- Provides lapacke.h header file needed for LAPACK C interface
- Fixes 'lapacke.h: No such file or directory' compilation error
- Enables complete OpenBLAS + LAPACKE support for ARM64 wheel builds

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with cosine_similarity.h x86 intrinsics fix

- Points to commit dbb17eb with corrected conditional compilation
- Fixes immintrin.h inclusion for ARM64 compatibility in cosine_similarity.h
- Resolves 'immintrin.h: No such file or directory' error
- Continues systematic ARM64 Linux compilation fixes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: Update DiskANN submodule with LAPACKE library linking fix

- Points to commit 19f9603 with explicit LAPACKE library discovery and linking
- Resolves 'undefined symbol: LAPACKE_sgesdd' runtime error on ARM64 Linux
- Completes ARM64 Linux wheel build compatibility for Docker deployments

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-09-02 14:27:06 -07:00
Yichuan Wang
d41e467df9 [CLI] More robust leann list and leann build (#84)
* chore(submodule): bump faiss to latest storage-efficient build

* [chore] add slack to share use case

* [cli] better gitignore / better leann list

* [cli] fix # 81
2025-09-01 18:36:27 -07:00
yichuan520030910320
4ca0489cb1 [chore] add slack to share use case 2025-09-01 13:31:16 -07:00
yichuan520030910320
e83a671918 chore(submodule): bump faiss to latest storage-efficient build 2025-09-01 13:31:12 -07:00
yichuan520030910320
4e5b73ce7b fix bug introduce in #58 2025-08-22 02:35:09 -07:00
Gabriel Dehan
31b4973141 Metadata filtering feature (#75)
* Metadata filtering initial version

* Metadata filtering initial version

* Fixes linter issues

* Cleanup code

* Clean up and readme

* Fix after review

* Use UV in example

* Merge main into feature/metadata-filtering
2025-08-20 19:57:56 -07:00
Yichuan Wang
dde2221513 [EXP] Update the benchmark code (#71)
* chore(hnsw): reorder imports to satisfy ruff I001

* chore: sync changes; fix Ruff import order; update examples, benchmarks, and dependencies

- Fix import order in packages/leann-backend-hnsw/leann_backend_hnsw/hnsw_backend.py (Ruff I001)

- Update benchmarks/run_evaluation.py

- Update apps/base_rag_example.py and leann-core API usage

- Add benchmarks/data/README.md

- Update uv.lock

- Misc cleanup

- Note: added paru-bin as an embedded git repo; consider making it a submodule (git rm --cached paru-bin) if unintended

* chore: remove unintended embedded repo paru-bin and ignore it

Fix CI: avoid missing .gitmodules entry by removing gitlink and adding to .gitignore.

* ci: retrigger after removing unintended gitlink (paru-bin)

* feat(benchmarks): add --batch-size option and plumb through to HNSW search (default 0)

* feat(hnsw): add batch_size to LeannSearcher.search and LeannChat.ask; forward only for HNSW backend

* chore(logging): surface recompute and batching params; enable INFO logging in benchmark

* feat(embeddings): add optional manual tokenization path (HF tokenizer+model) with mean pooling; default remains SentenceTransformer.encode

* fix micro bench and fix pre commit

* update readme

---------

Co-authored-by: yichuan-w <yichuan-w@users.noreply.github.com>
2025-08-20 17:31:46 -07:00
Andy Lee
6d11e86e71 Run Evaluation RPJ Wiki on Arch Linux (#74)
* chore: ignore benchmark data

* perf: avoid merging offset dicts for lower mem usage

* style: format

* docs: rpj_wiki
2025-08-20 12:25:54 -07:00
53 changed files with 7418 additions and 4061 deletions

50
.github/ISSUE_TEMPLATE/bug_report.yml vendored Normal file
View File

@@ -0,0 +1,50 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: What happened?
description: A clear description of the bug
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: How to reproduce
placeholder: |
1. Install with...
2. Run command...
3. See error
validations:
required: true
- type: textarea
id: error
attributes:
label: Error message
description: Paste any error messages
render: shell
- type: input
id: version
attributes:
label: LEANN Version
placeholder: "0.1.0"
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating System
options:
- macOS
- Linux
- Windows
- Docker
validations:
required: true

8
.github/ISSUE_TEMPLATE/config.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: Documentation
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
about: Read the docs first
- name: Discussions
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
about: Ask questions and share ideas

View File

@@ -0,0 +1,27 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]
body:
- type: textarea
id: problem
attributes:
label: What problem does this solve?
description: Describe the problem or need
validations:
required: true
- type: textarea
id: solution
attributes:
label: Proposed solution
description: How would you like this to work?
validations:
required: true
- type: textarea
id: example
attributes:
label: Example usage
description: Show how the API might look
render: python

13
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,13 @@
## What does this PR do?
<!-- Brief description of your changes -->
## Related Issues
Fixes #
## Checklist
- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)

View File

@@ -54,6 +54,17 @@ jobs:
python: '3.12'
- os: ubuntu-22.04
python: '3.13'
# ARM64 Linux builds
- os: ubuntu-24.04-arm
python: '3.9'
- os: ubuntu-24.04-arm
python: '3.10'
- os: ubuntu-24.04-arm
python: '3.11'
- os: ubuntu-24.04-arm
python: '3.12'
- os: ubuntu-24.04-arm
python: '3.13'
- os: macos-14
python: '3.9'
- os: macos-14
@@ -108,13 +119,46 @@ jobs:
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
patchelf
# Install Intel MKL for DiskANN
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
# Debug: Show system information
echo "🔍 System Information:"
echo "Architecture: $(uname -m)"
echo "OS: $(uname -a)"
echo "CPU info: $(lscpu | head -5)"
# Install math library based on architecture
ARCH=$(uname -m)
echo "🔍 Setting up math library for architecture: $ARCH"
if [[ "$ARCH" == "x86_64" ]]; then
# Install Intel MKL for DiskANN on x86_64
echo "📦 Installing Intel MKL for x86_64..."
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
source /opt/intel/oneapi/setvars.sh
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
echo "✅ Intel MKL installed for x86_64"
# Debug: Check MKL installation
echo "🔍 MKL Installation Check:"
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
elif [[ "$ARCH" == "aarch64" ]]; then
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
echo "📦 Installing OpenBLAS for ARM64..."
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
echo "✅ OpenBLAS installed for ARM64"
# Debug: Check OpenBLAS installation
echo "🔍 OpenBLAS Installation Check:"
dpkg -l | grep openblas || echo "OpenBLAS package not found"
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
fi
# Debug: Show final library paths
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
- name: Install system dependencies (macOS)
if: runner.os == 'macOS'

6
.gitignore vendored
View File

@@ -22,6 +22,7 @@ demo/experiment_results/**/*.json
*.sh
*.txt
!CMakeLists.txt
!llms.txt
latency_breakdown*.json
experiment_results/eval_results/diskann/*.json
aws/
@@ -93,5 +94,10 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
batchtest.py
tests/__pytest_cache__/
tests/__pycache__/
paru-bin/
CLAUDE.md
CLAUDE.local.md
.claude/*.local.*
.claude/local/*
benchmarks/data/

3
.gitmodules vendored
View File

@@ -14,3 +14,6 @@
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
path = packages/leann-backend-hnsw/third_party/libzmq
url = https://github.com/zeromq/libzmq.git
[submodule "packages/astchunk-leann"]
path = packages/astchunk-leann
url = https://github.com/yichuan-w/astchunk-leann.git

View File

@@ -13,4 +13,5 @@ repos:
rev: v0.12.7 # Fixed version to match pyproject.toml
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format

View File

@@ -8,6 +8,8 @@
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
</p>
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
@@ -176,8 +178,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
**AST-Aware Code Chunking** - LEANN also features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files, providing improved code understanding compared to traditional text-based approaches.
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md) to learn more.
### Generation Model Setup
@@ -221,7 +222,8 @@ ollama pull llama3.2:1b
</details>
### ⭐ Flexible Configuration
## ⭐ Flexible Configuration
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
@@ -477,6 +479,15 @@ Once the index is built, you can ask questions like:
### 🚀 Claude Code Integration: Transform Your Development Workflow!
<details>
<summary><strong>NEW!! ASTAware Code Chunking</strong></summary>
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
</details>
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
**Key features:**
@@ -535,6 +546,9 @@ leann search my-docs "machine learning concepts"
# Interactive chat with your documents
leann ask my-docs --interactive
# Ask a single question (non-interactive)
leann ask my-docs "Where are prompts configured?"
# List all your indexes
leann list
@@ -618,6 +632,46 @@ Options:
</details>
## 🚀 Advanced Features
### 🎯 Metadata Filtering
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
```python
# Add metadata during indexing
builder.add_text(
"def authenticate_user(token): ...",
metadata={"file_extension": ".py", "lines_of_code": 25}
)
# Search with filters
results = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
### 🔍 Grep Search
For exact text matching instead of semantic search, use the `use_grep` parameter:
```python
# Exact text search
results = searcher.search("bananacrocodile", use_grep=True, top_k=1)
```
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
📖 **[Complete grep search guide →](docs/grep_search.md)**
## 🏗️ Architecture & How It Works
<p align="center">
@@ -697,6 +751,9 @@ MIT License - see [LICENSE](LICENSE) for details.
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
We welcome more contributors! Feel free to open issues or submit PRs.
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).

View File

@@ -11,6 +11,7 @@ from typing import Any
import dotenv
from leann.api import LeannBuilder, LeannChat
from leann.registry import register_project_directory
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
dotenv.load_dotenv()
@@ -78,6 +79,24 @@ class BaseRAGExample(ABC):
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
)
embedding_group.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
embedding_group.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
embedding_group.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
# LLM parameters
llm_group = parser.add_argument_group("LLM Parameters")
@@ -97,8 +116,8 @@ class BaseRAGExample(ABC):
llm_group.add_argument(
"--llm-host",
type=str,
default="http://localhost:11434",
help="Host for Ollama API (default: http://localhost:11434)",
default=None,
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
llm_group.add_argument(
"--thinking-budget",
@@ -107,6 +126,18 @@ class BaseRAGExample(ABC):
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
llm_group.add_argument(
"--llm-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs",
)
llm_group.add_argument(
"--llm-api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# AST Chunking parameters
ast_group = parser.add_argument_group("AST Chunking Parameters")
@@ -205,9 +236,13 @@ class BaseRAGExample(ABC):
if args.llm == "openai":
config["model"] = args.llm_model or "gpt-4o"
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
resolved_key = resolve_openai_api_key(args.llm_api_key)
if resolved_key:
config["api_key"] = resolved_key
elif args.llm == "ollama":
config["model"] = args.llm_model or "llama3.2:1b"
config["host"] = args.llm_host
config["host"] = resolve_ollama_host(args.llm_host)
elif args.llm == "hf":
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
elif args.llm == "simulated":
@@ -223,10 +258,20 @@ class BaseRAGExample(ABC):
print(f"\n[Building Index] Creating {self.name} index...")
print(f"Total text chunks: {len(texts)}")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend_name,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.build_complexity,
is_compact=not args.no_compact,
@@ -299,7 +344,6 @@ class BaseRAGExample(ABC):
chat = LeannChat(
index_path,
llm_config=self.get_llm_config(args),
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
complexity=args.search_complexity,
)

View File

@@ -10,7 +10,8 @@ from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample, create_text_chunks
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .history_data.history import ChromeHistoryReader

View File

@@ -1,16 +1,38 @@
"""
Chunking utilities for LEANN RAG applications.
Provides AST-aware and traditional text chunking functionality.
"""Unified chunking utilities facade.
This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
"""
from .utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
import sys
from pathlib import Path
try:
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
except Exception: # pragma: no cover - best-effort fallback for dev environment
repo_root = Path(__file__).resolve().parents[2]
leann_src = repo_root / "packages" / "leann-core" / "src"
if leann_src.exists():
sys.path.insert(0, str(leann_src))
from leann.chunking_utils import (
CODE_EXTENSIONS,
create_ast_chunks,
create_text_chunks,
create_traditional_chunks,
detect_code_files,
get_language_from_extension,
)
else:
raise
__all__ = [
"CODE_EXTENSIONS",

View File

@@ -9,7 +9,8 @@ from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))
from base_rag_example import BaseRAGExample, create_text_chunks
from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from .email_data.LEANN_email_reader import EmlxReader

View File

@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
if count >= max_count and max_count > 0:
break
last_visit, url, title, visit_count, typed_count, hidden = row
last_visit, url, title, visit_count, typed_count, _hidden = row
# Create document content with metadata embedded in text
doc_content = f"""

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 152 KiB

44
benchmarks/data/README.md Executable file
View File

@@ -0,0 +1,44 @@
---
license: mit
---
# LEANN-RAG Evaluation Data
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
## Dataset Components
This dataset is structured into three main parts:
1. **Pre-built LEANN Indices**:
* `dpr/`: A pre-built index for the DPR dataset.
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
2. **Ground Truth Data**:
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
3. **Queries**:
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
## Usage
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
```bash
pip install huggingface-hub
```
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
```python
from huggingface_hub import snapshot_download
snapshot_download(
repo_id="LEANN-RAG/leann-rag-evaluation-data",
repo_type="dataset",
local_dir="data"
)
```
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.

View File

@@ -12,7 +12,7 @@ import time
from pathlib import Path
import numpy as np
from leann.api import LeannBuilder, LeannSearcher
from leann.api import LeannBuilder, LeannChat, LeannSearcher
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
@@ -197,6 +197,25 @@ def main():
parser.add_argument(
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Batch size for HNSW batched search (0 disables batching)",
)
parser.add_argument(
"--llm-type",
type=str,
choices=["ollama", "hf", "openai", "gemini", "simulated"],
default="ollama",
help="LLM backend type to optionally query during evaluation (default: ollama)",
)
parser.add_argument(
"--llm-model",
type=str,
default="qwen3:1.7b",
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
)
args = parser.parse_args()
# --- Path Configuration ---
@@ -318,9 +337,24 @@ def main():
for i in range(num_eval_queries):
start_time = time.time()
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
new_results = searcher.search(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
search_times.append(time.time() - start_time)
# Optional: also call the LLM with configurable backend/model (does not affect recall)
llm_config = {"type": args.llm_type, "model": args.llm_model}
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
answer = chat.ask(
queries[i],
top_k=args.top_k,
complexity=args.ef_search,
batch_size=args.batch_size,
)
print(f"Answer: {answer}")
# Correct Recall Calculation: Based on TEXT content
new_texts = {result.text for result in new_results}

View File

@@ -20,7 +20,7 @@ except ImportError:
@dataclass
class BenchmarkConfig:
model_path: str = "facebook/contriever"
model_path: str = "facebook/contriever-msmarco"
batch_sizes: list[int] = None
seq_length: int = 256
num_runs: int = 5
@@ -34,7 +34,7 @@ class BenchmarkConfig:
def __post_init__(self):
if self.batch_sizes is None:
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
class MLXBenchmark:
@@ -179,11 +179,14 @@ class Benchmark:
def _run_inference(self, input_ids: torch.Tensor) -> float:
attention_mask = torch.ones_like(input_ids)
# print shape of input_ids and attention_mask
print(f"input_ids shape: {input_ids.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
start_time = time.time()
with torch.no_grad():
self.model(input_ids=input_ids, attention_mask=attention_mask)
# mps sync
if torch.cuda.is_available():
torch.cuda.synchronize()
if torch.backends.mps.is_available():
torch.mps.synchronize()
end_time = time.time()

View File

@@ -26,6 +26,21 @@ leann build my-code-index --docs ./src --use-ast-chunking
uv pip install -e "."
```
#### For normal users (PyPI install)
- Use `pip install leann` or `uv pip install leann`.
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
#### For developers (from source, editable)
```bash
git clone https://github.com/yichuan-w/LEANN.git leann
cd leann
git submodule update --init --recursive
uv sync
```
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
## Best Practices
### When to Use AST Chunking

View File

@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
</details>
## Local & Remote Inference Endpoints
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint either on the same machine or across the network with a couple of flags or environment variables.
### One-Time Environment Setup
```bash
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
export OPENAI_BASE_URL="http://localhost:1234/v1"
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
```
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
### Passing Hosts Per Command
```bash
# Build an index with a remote embedding server
leann build my-notes \
--docs ./notes \
--embedding-mode openai \
--embedding-model text-embedding-qwen3-embedding-0.6b \
--embedding-api-base http://192.168.1.50:1234/v1 \
--embedding-api-key local-dev-key
# Query using a local LM Studio instance via OpenAI-compatible API
leann ask my-notes \
--llm openai \
--llm-model qwen3-8b \
--api-base http://localhost:1234/v1 \
--api-key local-dev-key
# Query an Ollama instance running on another box
leann ask my-notes \
--llm ollama \
--llm-model qwen3:14b \
--host http://192.168.1.101:11434
```
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
- Configure router or cloud provider port forwarding.
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
**Tip:** If your runtime does not require an API key (many local stacks dont), leave `--api-key` unset. LEANN will skip injecting credentials.
### Python API Usage
You can pass the same configuration from Python:
```python
from leann.api import LeannBuilder
builder = LeannBuilder(
backend_name="hnsw",
embedding_mode="openai",
embedding_model="text-embedding-qwen3-embedding-0.6b",
embedding_options={
"base_url": "http://192.168.1.50:1234/v1",
"api_key": "local-dev-key",
},
)
builder.build_index("./indexes/my-notes", chunks)
```
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
## Index Selection: Matching Your Scale
### HNSW (Hierarchical Navigable Small World)

149
docs/grep_search.md Normal file
View File

@@ -0,0 +1,149 @@
# LEANN Grep Search Usage Guide
## Overview
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
## Basic Usage
### Simple Grep Search
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("your_index_path")
# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:100]}...")
print("-" * 40)
```
### Comparison: Semantic vs Grep Search
```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)
# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```
## When to Use Grep Search
### Use Cases
- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology
### Examples
```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)
# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)
# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)
# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)
# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```
## Technical Details
### How It Works
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
### Search Process
```
Query: "def train_model"
grep -i -n "def train_model" documents.leann.passages.jsonl
Parse matching JSON lines
Calculate scores based on term frequency
Return top_k results
```
### Scoring Algorithm
```python
# Term frequency in document
score = text.lower().count(query.lower())
```
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
## Error Handling
### Common Issues
#### Grep Command Not Found
```
RuntimeError: grep command not found. Please install grep or use semantic search.
```
**Solution**: Install grep on your system:
- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
#### No Results Found
```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
print("No exact matches found. Try:")
print("1. Check spelling and case")
print("2. Use partial terms")
print("3. Switch to semantic search")
```
## Complete Example
```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""
from leann.api import LeannSearcher
def demonstrate_grep_search():
# Initialize searcher
searcher = LeannSearcher("my_index")
print("=== Function Search ===")
functions = searcher.search("def __init__", use_grep=True, top_k=5)
for i, result in enumerate(functions, 1):
print(f"{i}. Score: {result.score}")
print(f" Preview: {result.text[:60]}...")
print()
print("=== Error Search ===")
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
for result in errors:
print(f"Content: {result.text.strip()}")
print("-" * 40)
if __name__ == "__main__":
demonstrate_grep_search()
```

300
docs/metadata_filtering.md Normal file
View File

@@ -0,0 +1,300 @@
# LEANN Metadata Filtering Usage Guide
## Overview
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
## Basic Usage
### Adding Metadata to Your Documents
When building your index, add metadata to each text chunk:
```python
from leann.api import LeannBuilder
builder = LeannBuilder("hnsw")
# Add text with metadata
builder.add_text(
text="Chapter 1: Alice falls down the rabbit hole",
metadata={
"chapter": 1,
"character": "Alice",
"themes": ["adventure", "curiosity"],
"word_count": 150
}
)
builder.build_index("alice_in_wonderland_index")
```
### Searching with Metadata Filters
Use the `metadata_filters` parameter in search calls:
```python
from leann.api import LeannSearcher
searcher = LeannSearcher("alice_in_wonderland_index")
# Search with filters
results = searcher.search(
query="What happens to Alice?",
top_k=10,
metadata_filters={
"chapter": {"<=": 5}, # Only chapters 1-5
"spoiler_level": {"!=": "high"} # No high spoilers
}
)
```
## Filter Syntax
### Basic Structure
```python
metadata_filters = {
"field_name": {"operator": value},
"another_field": {"operator": value}
}
```
### Supported Operators
#### Comparison Operators
- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal
```python
# Examples
{"chapter": {"==": 1}} # Exactly chapter 1
{"page": {">": 100}} # Pages after 100
{"rating": {">=": 4.0}} # Rating 4.0 or higher
{"word_count": {"<": 500}} # Short passages
```
#### Membership Operators
- `"in"`: Value is in list
- `"not_in"`: Value is not in list
```python
# Examples
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
```
#### String Operators
- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix
```python
# Examples
{"title": {"contains": "alice"}} # Title contains "alice"
{"filename": {"ends_with": ".py"}} # Python files
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
```
#### Boolean Operators
- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy
```python
# Examples
{"is_published": {"is_true": True}} # Published content
{"is_draft": {"is_false": False}} # Not drafts
```
### Multiple Operators on Same Field
You can apply multiple operators to the same field (AND logic):
```python
metadata_filters = {
"word_count": {
">=": 100, # At least 100 words
"<=": 500 # At most 500 words
}
}
```
### Compound Filters
Multiple fields are combined with AND logic:
```python
metadata_filters = {
"chapter": {"<=": 10}, # Up to chapter 10
"character": {"==": "Alice"}, # About Alice
"spoiler_level": {"!=": "high"} # No major spoilers
}
```
## Use Case Examples
### 1. Spoiler-Free Book Search
```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
return searcher.search(
query=query,
metadata_filters={
"chapter": {"<=": max_chapter},
"spoiler_level": {"in": ["none", "low"]}
}
)
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```
### 2. Document Management by Date
```python
# Find recent documents
recent_docs = searcher.search(
query="project updates",
metadata_filters={
"date": {">=": "2024-01-01"},
"document_type": {"==": "report"}
}
)
```
### 3. Code Search by File Type
```python
# Search only Python files
python_code = searcher.search(
query="authentication function",
metadata_filters={
"file_extension": {"==": ".py"},
"lines_of_code": {"<": 100}
}
)
```
### 4. Content Filtering by Audience
```python
# Age-appropriate content
family_content = searcher.search(
query="adventure stories",
metadata_filters={
"age_rating": {"in": ["G", "PG"]},
"content_warnings": {"not_in": ["violence", "adult_themes"]}
}
)
```
### 5. Multi-Book Series Management
```python
# Search across first 3 books only
early_series = searcher.search(
query="character development",
metadata_filters={
"series": {"==": "Harry Potter"},
"book_number": {"<=": 3}
}
)
```
## Running the Example
You can see metadata filtering in action with our spoiler-free book RAG example:
```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"
# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```
This example demonstrates:
- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
## Advanced Patterns
### Custom Chunking with metadata
```python
def chunk_book_with_metadata(book_text, book_info):
chunks = []
for chapter_num, chapter_text in parse_chapters(book_text):
# Extract entities, themes, etc.
characters = extract_characters(chapter_text)
themes = classify_themes(chapter_text)
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
# Create chunks with rich metadata
for paragraph in split_paragraphs(chapter_text):
chunks.append({
"text": paragraph,
"metadata": {
"book_title": book_info["title"],
"chapter": chapter_num,
"characters": characters,
"themes": themes,
"spoiler_level": spoiler_level,
"word_count": len(paragraph.split()),
"reading_level": calculate_reading_level(paragraph)
}
})
return chunks
```
## Performance Considerations
### Efficient Filtering Strategies
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
### Best Practices
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
### Adding Metadata to Existing Indices
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
for chunk in chunks:
# Extract or assign metadata based on content
chunk["metadata"] = extract_metadata(chunk["text"])
return chunks
# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

0
examples/__init__.py Normal file
View File

View File

@@ -0,0 +1,404 @@
"""Dynamic HNSW update demo without compact storage.
This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
uv run -m examples.dynamic_update_no_recompute \
--index-path .leann/examples/leann-demo.leann
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.
To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""
import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory
from apps.chunking import create_text_chunks
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
def load_chunks_from_files(paths: list[Path]) -> list[str]:
from llama_index.core import SimpleDirectoryReader
documents = []
for path in paths:
p = path.expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input path not found: {p}")
if p.is_dir():
reader = SimpleDirectoryReader(str(p), recursive=False)
documents.extend(reader.load_data(show_progress=True))
else:
reader = SimpleDirectoryReader(input_files=[str(p)])
documents.extend(reader.load_data(show_progress=True))
if not documents:
return []
chunks = create_text_chunks(
documents,
chunk_size=512,
chunk_overlap=128,
use_ast_chunking=False,
)
return [c for c in chunks if isinstance(c, str) and c.strip()]
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
searcher = LeannSearcher(str(index_path))
try:
return searcher.search(
query=query,
top_k=top_k,
recompute_embeddings=recompute_embeddings,
batch_size=16,
)
finally:
searcher.cleanup()
def print_results(title: str, results: Iterable) -> None:
print(f"\n=== {title} ===")
res_list = list(results)
print(f"results count: {len(res_list)}")
print("passages:")
if not res_list:
print(" (no passages returned)")
for res in res_list:
snippet = res.text.replace("\n", " ")[:120]
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
def build_initial_index(
index_path: Path,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
builder = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for idx, passage in enumerate(paragraphs):
builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(str(index_path))
def update_index(
index_path: Path,
start_id: int,
paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
) -> None:
updater = LeannBuilder(
backend_name="hnsw",
embedding_model=model_name,
embedding_mode=embedding_mode,
is_compact=False,
is_recompute=is_recompute,
)
for offset, passage in enumerate(paragraphs, start=start_id):
updater.add_text(passage, metadata={"id": str(offset)})
updater.update_index(str(index_path))
def ensure_index_dir(index_path: Path) -> None:
index_path.parent.mkdir(parents=True, exist_ok=True)
def cleanup_index_files(index_path: Path) -> None:
"""Remove leftover index artifacts for a clean rebuild."""
parent = index_path.parent
if not parent.exists():
return
stem = index_path.stem
for file in parent.glob(f"{stem}*"):
if file.is_file():
file.unlink()
def index_file_size(index_path: Path) -> int:
"""Return the size of the primary .index file for the given index path."""
index_file = index_path.parent / f"{index_path.stem}.index"
return index_file.stat().st_size if index_file.exists() else 0
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
meta_path = index_path.parent / f"{index_path.name}.meta.json"
if not meta_path.exists():
return None
try:
return json.loads(meta_path.read_text())
except json.JSONDecodeError:
return None
def run_workflow(
*,
label: str,
index_path: Path,
initial_paragraphs: list[str],
update_paragraphs: list[str],
model_name: str,
embedding_mode: str,
is_recompute: bool,
query: str,
top_k: int,
) -> dict[str, Any]:
prefix = f"[{label}] " if label else ""
ensure_index_dir(index_path)
cleanup_index_files(index_path)
print(f"{prefix}Building initial index...")
build_initial_index(
index_path,
initial_paragraphs,
model_name,
embedding_mode,
is_recompute=is_recompute,
)
initial_size = index_file_size(index_path)
before_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
print(f"\n{prefix}Updating index with additional passages...")
update_index(
index_path,
start_id=len(initial_paragraphs),
paragraphs=update_paragraphs,
model_name=model_name,
embedding_mode=embedding_mode,
is_recompute=is_recompute,
)
after_results = run_search(
index_path,
query,
top_k,
recompute_embeddings=is_recompute,
)
updated_size = index_file_size(index_path)
return {
"initial_size": initial_size,
"updated_size": updated_size,
"delta": updated_size - initial_size,
"before_results": before_results,
"after_results": after_results,
"metadata": load_metadata_snapshot(index_path),
}
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--initial-files",
type=Path,
nargs="+",
default=DEFAULT_INITIAL_FILES,
help="Initial document files (PDF/TXT) used to build the base index",
)
parser.add_argument(
"--index-path",
type=Path,
default=Path(".leann/examples/leann-demo.leann"),
help="Destination index path (default: .leann/examples/leann-demo.leann)",
)
parser.add_argument(
"--initial-count",
type=int,
default=8,
help="Number of chunks to use from the initial documents (default: 8)",
)
parser.add_argument(
"--update-files",
type=Path,
nargs="*",
default=DEFAULT_UPDATE_FILES,
help="Additional documents to add during update (PDF/TXT)",
)
parser.add_argument(
"--update-count",
type=int,
default=4,
help="Number of chunks to append from update documents (default: 4)",
)
parser.add_argument(
"--update-text",
type=str,
default=(
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
"on demand to keep disk usage minimal."
),
help="Fallback text to append if --update-files is omitted",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of results to show for each search (default: 4)",
)
parser.add_argument(
"--query",
type=str,
default=DEFAULT_QUERY,
help="Query to run before/after the update",
)
parser.add_argument(
"--embedding-model",
type=str,
default="sentence-transformers/all-MiniLM-L6-v2",
help="Embedding model name",
)
parser.add_argument(
"--embedding-mode",
type=str,
default="sentence-transformers",
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode",
)
parser.add_argument(
"--compare-no-recompute",
dest="compare_no_recompute",
action="store_true",
help="Also run a baseline with is_recompute=False and report its index growth.",
)
parser.add_argument(
"--skip-compare-no-recompute",
dest="compare_no_recompute",
action="store_false",
help="Skip building the no-recompute baseline.",
)
parser.set_defaults(compare_no_recompute=True)
args = parser.parse_args()
ensure_index_dir(args.index_path)
register_project_directory(REPO_ROOT)
initial_chunks = load_chunks_from_files(list(args.initial_files))
if not initial_chunks:
raise ValueError("No text chunks extracted from the initial files.")
initial = initial_chunks[: args.initial_count]
if not initial:
raise ValueError("Initial chunk set is empty after applying --initial-count.")
if args.update_files:
update_chunks = load_chunks_from_files(list(args.update_files))
if not update_chunks:
raise ValueError("No text chunks extracted from the update files.")
to_add = update_chunks[: args.update_count]
else:
if not args.update_text:
raise ValueError("Provide --update-files or --update-text for the update step.")
to_add = [args.update_text]
if not to_add:
raise ValueError("Update chunk set is empty after applying --update-count.")
recompute_stats = run_workflow(
label="recompute",
index_path=args.index_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=True,
query=args.query,
top_k=args.top_k,
)
print_results("initial search", recompute_stats["before_results"])
print_results("after update", recompute_stats["after_results"])
print(
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
f"{recompute_stats['delta']})"
)
if recompute_stats["metadata"]:
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if args.compare_no_recompute:
baseline_path = (
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
)
baseline_stats = run_workflow(
label="no-recompute",
index_path=baseline_path,
initial_paragraphs=initial,
update_paragraphs=to_add,
model_name=args.embedding_model,
embedding_mode=args.embedding_mode,
is_recompute=False,
query=args.query,
top_k=args.top_k,
)
print(
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
f"{baseline_stats['delta']})"
)
after_texts = [res.text for res in recompute_stats["after_results"]]
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
if after_texts == baseline_after_texts:
print(
"[no-recompute] Search results match recompute baseline; see above for the shared output."
)
else:
print("[no-recompute] WARNING: search results differ from recompute baseline.")
if baseline_stats["metadata"]:
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
print("[no-recompute] metadata snapshot:")
print(json.dumps(meta_view, indent=2))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,35 @@
"""
Grep Search Example
Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""
from leann import LeannSearcher
# Load your index
searcher = LeannSearcher("my-documents.leann")
# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
print(f"Score: {result.score:.3f}")
print(f"Text: {result.text[:80]}...")
print()
# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
print(f"Score: {result.score}")
print(f"Text: {result.text[:80]}...")
print()
# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
# Search for function definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")

View File

@@ -0,0 +1,250 @@
#!/usr/bin/env python3
"""
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
This example demonstrates how to use LEANN's metadata filtering to create
a spoiler-free book RAG system where users can search for information
up to a specific chapter they've read.
Usage:
python spoiler_free_book_rag.py
"""
import os
import sys
from typing import Any, Optional
# Add LEANN to path (adjust path as needed)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
from leann.api import LeannBuilder, LeannSearcher
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
"""
Create sample book chunks with metadata for demonstration.
In a real implementation, this would parse actual book files (epub, txt, etc.)
and extract chapter boundaries, character mentions, etc.
Args:
book_title: Title of the book
Returns:
List of chunk dictionaries with text and metadata
"""
# Sample book chunks with metadata
# In practice, you'd use proper text processing libraries
sample_chunks = [
{
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 1,
"characters": ["Alice", "Sister"],
"themes": ["boredom", "curiosity"],
"location": "riverbank",
},
},
{
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
"metadata": {
"book": book_title,
"chapter": 1,
"page": 2,
"characters": ["Alice", "White Rabbit"],
"themes": ["decision", "surprise", "magic"],
"location": "riverbank",
},
},
{
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
"metadata": {
"book": book_title,
"chapter": 2,
"page": 15,
"characters": ["Alice"],
"themes": ["falling", "wonder", "transformation"],
"location": "rabbit hole",
},
},
{
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
"metadata": {
"book": book_title,
"chapter": 6,
"page": 85,
"characters": ["Alice", "Cheshire Cat"],
"themes": ["madness", "philosophy", "identity"],
"location": "Duchess's house",
},
},
{
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
"metadata": {
"book": book_title,
"chapter": 8,
"page": 120,
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
"themes": ["justice", "absurdity", "authority"],
"location": "Queen's court",
},
},
{
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
"metadata": {
"book": book_title,
"chapter": 12,
"page": 180,
"characters": ["Alice", "Sister", "Rabbit"],
"themes": ["revelation", "reality", "growth"],
"location": "riverbank",
},
},
]
return sample_chunks
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
"""
Build a LEANN index with book chunks that include spoiler metadata.
Args:
book_chunks: List of book chunks with metadata
index_name: Name for the index
Returns:
Path to the built index
"""
print(f"📚 Building spoiler-free book index: {index_name}")
# Initialize LEANN builder
builder = LeannBuilder(
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
)
# Add each chunk with its metadata
for chunk in book_chunks:
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
# Build the index
index_path = f"{index_name}_book_index"
builder.build_index(index_path)
print(f"✅ Index built successfully: {index_path}")
return index_path
def spoiler_free_search(
index_path: str,
query: str,
max_chapter: int,
character_filter: Optional[list[str]] = None,
) -> list[dict[str, Any]]:
"""
Perform a spoiler-free search on the book index.
Args:
index_path: Path to the LEANN index
query: Search query
max_chapter: Maximum chapter number to include
character_filter: Optional list of characters to focus on
Returns:
List of search results safe for the reader
"""
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
searcher = LeannSearcher(index_path)
metadata_filters = {"chapter": {"<=": max_chapter}}
if character_filter:
metadata_filters["characters"] = {"contains": character_filter[0]}
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
return results
def demo_spoiler_free_rag():
"""
Demonstrate the spoiler-free book RAG system.
"""
print("🎭 Spoiler-Free Book RAG Demo")
print("=" * 40)
# Step 1: Prepare book data
book_title = "Alice's Adventures in Wonderland"
book_chunks = chunk_book_with_metadata(book_title)
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
# Step 2: Build the index (in practice, this would be done once)
try:
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
except Exception as e:
print(f"❌ Failed to build index (likely missing dependencies): {e}")
print(
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
)
return
# Step 3: Demonstrate various spoiler-free searches
search_scenarios = [
{
"description": "Reader who has only read Chapter 1",
"query": "What can you tell me about the rabbit?",
"max_chapter": 1,
},
{
"description": "Reader who has read up to Chapter 5",
"query": "Tell me about Alice's adventures",
"max_chapter": 5,
},
{
"description": "Reader who has read most of the book",
"query": "What does the Cheshire Cat represent?",
"max_chapter": 10,
},
{
"description": "Reader who has read the whole book",
"query": "What can you tell me about the rabbit?",
"max_chapter": 12,
},
]
for scenario in search_scenarios:
print(f"\n📚 Scenario: {scenario['description']}")
print(f" Query: {scenario['query']}")
try:
results = spoiler_free_search(
index_path=index_path,
query=scenario["query"],
max_chapter=scenario["max_chapter"],
)
print(f" 📄 Found {len(results)} results:")
for i, result in enumerate(results[:3], 1): # Show top 3
chapter = result.metadata.get("chapter", "?")
location = result.metadata.get("location", "?")
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
except Exception as e:
print(f" ❌ Search failed: {e}")
if __name__ == "__main__":
print("📚 LEANN Spoiler-Free Book RAG Example")
print("=====================================")
try:
demo_spoiler_free_rag()
except ImportError as e:
print(f"❌ Cannot run demo due to missing dependencies: {e}")
except Exception as e:
print(f"❌ Error running demo: {e}")

28
llms.txt Normal file
View File

@@ -0,0 +1,28 @@
# llms.txt — LEANN MCP and Agent Integration
product: LEANN
homepage: https://github.com/yichuan-w/LEANN
contact: https://github.com/yichuan-w/LEANN/issues
# Installation
install: uv tool install leann-core --with leann
# MCP Server Entry Point
mcp.server: leann_mcp
mcp.protocol_version: 2024-11-05
# Tools
mcp.tools: leann_list, leann_search
mcp.tool.leann_list.description: List available LEANN indexes
mcp.tool.leann_list.input: {}
mcp.tool.leann_search.description: Semantic search across a named LEANN index
mcp.tool.leann_search.input.index_name: string, required
mcp.tool.leann_search.input.query: string, required
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
# Notes
note: Build indexes with `leann build <name> --docs <files...>` before searching.
example.add: claude mcp add --scope user leann-server -- leann_mcp
example.verify: claude mcp list | cat

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import numpy as np
import zmq
@@ -32,6 +32,16 @@ if not logger.handlers:
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_diskann_embedding_server(
passages_file: Optional[str] = None,
zmq_port: int = 5555,
@@ -181,7 +191,12 @@ def create_diskann_embedding_server(
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
# Process embeddings using unified computation
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)
@@ -296,7 +311,12 @@ def create_diskann_embedding_server(
continue
# Process the request
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(f"Computed embeddings shape: {embeddings.shape}")
# Validation

View File

@@ -1,11 +1,11 @@
[build-system]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-diskann"
version = "0.3.2"
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
version = "0.3.4"
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
[tool.scikit-build]
# Key: simplified CMake path

View File

@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
# Disable additional SIMD versions to speed up compilation
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
# ARM64-specific configuration
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
message(STATUS "Configuring Faiss for ARM64 architecture")
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
else()
# Use generic optimization for other ARM64 platforms (like macOS)
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
endif()
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
message(STATUS "Using ARM64-compatible Faiss submodule")
endif()
# Additional optimization options from INSTALL.md
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)

View File

@@ -5,6 +5,8 @@ import os
import struct
import sys
import time
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
@@ -237,6 +239,288 @@ def write_compact_format(
f_out.write(storage_data)
@dataclass
class HNSWComponents:
original_hnsw_data: dict[str, Any]
assign_probas_np: np.ndarray
cum_nneighbor_per_level_np: np.ndarray
levels_np: np.ndarray
is_compact: bool
compact_level_ptr: Optional[np.ndarray] = None
compact_node_offsets_np: Optional[np.ndarray] = None
compact_neighbors_data: Optional[list[int]] = None
offsets_np: Optional[np.ndarray] = None
neighbors_np: Optional[np.ndarray] = None
storage_fourcc: int = NULL_INDEX_FOURCC
storage_data: bytes = b""
def _read_hnsw_structure(f) -> HNSWComponents:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
raise ValueError(
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
)
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f, "<i")
original_hnsw_data["ntotal"] = read_struct(f, "<q")
original_hnsw_data["dummy1"] = read_struct(f, "<q")
original_hnsw_data["dummy2"] = read_struct(f, "<q")
original_hnsw_data["is_trained"] = read_struct(f, "?")
original_hnsw_data["metric_type"] = read_struct(f, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
assign_probas_np = read_numpy_vector(f, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
levels_np = read_numpy_vector(f, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = read_struct(f, "<I")
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
storage_data = f.read()
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=True,
compact_level_ptr=compact_level_ptr,
compact_node_offsets_np=compact_node_offsets_np,
compact_neighbors_data=compact_neighbors_data,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
# Non-compact case
f.seek(pos_before_compact)
pos_before_probe = f.tell()
try:
suspected_flag = read_struct(f, "<B")
if suspected_flag != 0x00:
f.seek(pos_before_probe)
except EOFError:
f.seek(pos_before_probe)
offsets_np = read_numpy_vector(f, np.uint64, "Q")
neighbors_np = read_numpy_vector(f, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f, "<i")
original_hnsw_data["max_level"] = read_struct(f, "<i")
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
original_hnsw_data["efSearch"] = read_struct(f, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
storage_fourcc = NULL_INDEX_FOURCC
storage_data = b""
try:
storage_fourcc = read_struct(f, "<I")
storage_data = f.read()
except EOFError:
storage_fourcc = NULL_INDEX_FOURCC
return HNSWComponents(
original_hnsw_data=original_hnsw_data,
assign_probas_np=assign_probas_np,
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
levels_np=levels_np,
is_compact=False,
offsets_np=offsets_np,
neighbors_np=neighbors_np,
storage_fourcc=storage_fourcc,
storage_data=storage_data,
)
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
with open(path, "rb") as f:
return _read_hnsw_structure(f)
def write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
storage_fourcc,
storage_data,
):
"""Write non-compact HNSW data in original FAISS order."""
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
if original_hnsw_data["metric_type"] > 1:
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
write_numpy_vector(f_out, assign_probas_np, "d")
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
write_numpy_vector(f_out, levels_np, "i")
write_numpy_vector(f_out, offsets_np, "Q")
write_numpy_vector(f_out, neighbors_np, "i")
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
f_out.write(struct.pack("<I", storage_fourcc))
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
f_out.write(storage_data)
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
"""Rewrite an HNSW index while dropping the embedded storage section."""
start_time = time.time()
try:
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
original_hnsw_data: dict[str, Any] = {}
hnsw_index_fourcc = read_struct(f_in, "<I")
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
print(
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
file=sys.stderr,
)
return False
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
original_hnsw_data["d"] = read_struct(f_in, "<i")
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
original_hnsw_data["metric_arg"] = 0.0
if original_hnsw_data["metric_type"] > 1:
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
levels_np = read_numpy_vector(f_in, np.int32, "i")
ntotal = len(levels_np)
if ntotal != original_hnsw_data["ntotal"]:
original_hnsw_data["ntotal"] = ntotal
pos_before_compact = f_in.tell()
is_compact_flag = None
try:
is_compact_flag = read_struct(f_in, "<?")
except EOFError:
is_compact_flag = None
if is_compact_flag:
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = read_struct(f_in, "<I")
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
compact_neighbors_data = compact_neighbors_data_np.tolist()
_storage_data = f_in.read()
write_compact_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
compact_level_ptr,
compact_node_offsets_np,
compact_neighbors_data,
NULL_INDEX_FOURCC,
b"",
)
else:
f_in.seek(pos_before_compact)
pos_before_probe = f_in.tell()
try:
suspected_flag = read_struct(f_in, "<B")
if suspected_flag != 0x00:
f_in.seek(pos_before_probe)
except EOFError:
f_in.seek(pos_before_probe)
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
_storage_fourcc = None
_storage_data = b""
try:
_storage_fourcc = read_struct(f_in, "<I")
_storage_data = f_in.read()
except EOFError:
_storage_fourcc = NULL_INDEX_FOURCC
write_original_format(
f_out,
original_hnsw_data,
assign_probas_np,
cum_nneighbor_per_level_np,
levels_np,
offsets_np,
neighbors_np,
NULL_INDEX_FOURCC,
b"",
)
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
return True
except Exception as exc:
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
return False
# --- Main Conversion Logic ---
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
pass
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
"""Convenience wrapper to prune embeddings in-place."""
temp_path = f"{index_filename}.prune.tmp"
success = prune_hnsw_embeddings(index_filename, temp_path)
if success:
try:
os.replace(temp_path, index_filename)
except Exception as exc: # pragma: no cover - defensive
logger.error(f"Failed to replace original index with pruned version: {exc}")
try:
os.remove(temp_path)
except OSError:
pass
return False
else:
try:
os.remove(temp_path)
except OSError:
pass
return success
# --- Script Execution ---
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@@ -1,6 +1,7 @@
import logging
import os
import shutil
import time
from pathlib import Path
from typing import Any, Literal, Optional
@@ -13,7 +14,7 @@ from leann.interface import (
from leann.registry import register_backend
from leann.searcher_base import BaseSearcher
from .convert_to_csr import convert_hnsw_graph_to_csr
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
logger = logging.getLogger(__name__)
@@ -91,6 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
if self.is_compact:
self._convert_to_csr(index_file)
elif self.is_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
def _convert_to_csr(self, index_file: Path):
"""Convert built index to CSR format"""
@@ -132,10 +135,10 @@ class HNSWSearcher(BaseSearcher):
if metric_enum is None:
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
self.is_compact, self.is_pruned = (
self.meta.get("is_compact", True),
self.meta.get("is_pruned", True),
)
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
index_file = self.index_dir / f"{self.index_path.stem}.index"
if not index_file.exists():
@@ -236,6 +239,7 @@ class HNSWSearcher(BaseSearcher):
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
search_time = time.time()
self._index.search(
query.shape[0],
faiss.swig_ptr(query),
@@ -244,7 +248,8 @@ class HNSWSearcher(BaseSearcher):
faiss.swig_ptr(labels),
params,
)
search_time = time.time() - search_time
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
return {"labels": string_labels, "distances": distances}

View File

@@ -10,7 +10,7 @@ import sys
import threading
import time
from pathlib import Path
from typing import Optional
from typing import Any, Optional
import msgpack
import numpy as np
@@ -24,13 +24,35 @@ logger = logging.getLogger(__name__)
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
logger.setLevel(log_level)
# Ensure we have a handler if none exists
# Ensure we have handlers if none exist
if not logger.handlers:
handler = logging.StreamHandler()
stream_handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
if log_path:
try:
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
file_formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
except Exception as exc: # pragma: no cover - best effort logging
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
logger.propagate = False
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
try:
PROVIDER_OPTIONS: dict[str, Any] = (
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
)
except json.JSONDecodeError:
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
PROVIDER_OPTIONS = {}
def create_hnsw_embedding_server(
@@ -138,7 +160,12 @@ def create_hnsw_embedding_server(
):
last_request_type = "text"
last_request_length = len(request)
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
request,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
rep_socket.send(msgpack.packb(embeddings.tolist()))
e2e_end = time.time()
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
@@ -187,7 +214,10 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(
texts, model_name, mode=embedding_mode
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
@@ -252,7 +282,12 @@ def create_hnsw_embedding_server(
if texts:
try:
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
embeddings = compute_embeddings(
texts,
model_name,
mode=embedding_mode,
provider_options=PROVIDER_OPTIONS,
)
logger.info(
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
)

View File

@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
[project]
name = "leann-backend-hnsw"
version = "0.3.2"
version = "0.3.4"
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
dependencies = [
"leann-core==0.3.2",
"leann-core==0.3.4",
"numpy",
"pyzmq>=23.0.0",
"msgpack>=1.0.0",

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann-core"
version = "0.3.2"
version = "0.3.4"
description = "Core API and plugin system for LEANN"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -6,18 +6,22 @@ with the correct, original embedding logic from the user's reference code.
import json
import logging
import pickle
import re
import subprocess
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union
import numpy as np
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
from leann.interface import LeannBackendSearcherInterface
from .chat import get_llm
from .interface import LeannBackendFactoryInterface
from .metadata_filter import MetadataFilterEngine
from .registry import BACKEND_REGISTRY
logger = logging.getLogger(__name__)
@@ -35,6 +39,7 @@ def compute_embeddings(
use_server: bool = True,
port: Optional[int] = None,
is_build=False,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Computes embeddings using different backends.
@@ -68,6 +73,7 @@ def compute_embeddings(
model_name,
mode=mode,
is_build=is_build,
provider_options=provider_options,
)
@@ -125,6 +131,7 @@ class PassageManager:
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
# per-shard maps and do a lightweight per-shard lookup on demand.
self._total_count: int = 0
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
index_name_base = None
@@ -212,6 +219,56 @@ class PassageManager:
continue
raise KeyError(f"Passage ID not found: {passage_id}")
def filter_search_results(
self,
search_results: list[SearchResult],
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
) -> list[SearchResult]:
"""
Apply metadata filters to search results.
Args:
search_results: List of SearchResult objects
metadata_filters: Filter specifications to apply
Returns:
Filtered list of SearchResult objects
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying metadata filters to {len(search_results)} results")
# Convert SearchResult objects to dictionaries for the filter engine
result_dicts = []
for result in search_results:
result_dicts.append(
{
"id": result.id,
"score": result.score,
"text": result.text,
"metadata": result.metadata,
}
)
# Apply filters using the filter engine
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
# Convert back to SearchResult objects
filtered_results = []
for result_dict in filtered_dicts:
filtered_results.append(
SearchResult(
id=result_dict["id"],
score=result_dict["score"],
text=result_dict["text"],
metadata=result_dict["metadata"],
)
)
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
return filtered_results
def __len__(self) -> int:
return self._total_count
@@ -223,6 +280,7 @@ class LeannBuilder:
embedding_model: str = "facebook/contriever",
dimensions: Optional[int] = None,
embedding_mode: str = "sentence-transformers",
embedding_options: Optional[dict[str, Any]] = None,
**backend_kwargs,
):
self.backend_name = backend_name
@@ -245,6 +303,7 @@ class LeannBuilder:
self.embedding_model = embedding_model
self.dimensions = dimensions
self.embedding_mode = embedding_mode
self.embedding_options = embedding_options or {}
# Check if we need to use cosine distance for normalized embeddings
normalized_embeddings_models = {
@@ -352,6 +411,7 @@ class LeannBuilder:
self.embedding_model,
self.embedding_mode,
use_server=False,
provider_options=self.embedding_options,
)[0]
)
path = Path(index_path)
@@ -391,6 +451,7 @@ class LeannBuilder:
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
string_ids = [chunk["id"] for chunk in self.chunks]
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
@@ -417,14 +478,15 @@ class LeannBuilder:
],
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = (
is_compact and is_recompute
) # Pruned only if compact and recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
@@ -539,18 +601,166 @@ class LeannBuilder:
"embeddings_source": str(embeddings_file),
}
if self.embedding_options:
meta_data["embedding_options"] = self.embedding_options
# Add storage status flags for HNSW backend
if self.backend_name == "hnsw":
is_compact = self.backend_kwargs.get("is_compact", True)
is_recompute = self.backend_kwargs.get("is_recompute", True)
meta_data["is_compact"] = is_compact
meta_data["is_pruned"] = is_compact and is_recompute
meta_data["is_pruned"] = bool(is_recompute)
with open(leann_meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
def update_index(self, index_path: str):
"""Append new passages and vectors to an existing HNSW index."""
if not self.chunks:
raise ValueError("No new chunks provided for update.")
path = Path(index_path)
index_dir = path.parent
index_name = path.name
index_prefix = path.stem
meta_path = index_dir / f"{index_name}.meta.json"
passages_file = index_dir / f"{index_name}.passages.jsonl"
offset_file = index_dir / f"{index_name}.passages.idx"
index_file = index_dir / f"{index_prefix}.index"
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
if not index_file.exists():
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
with open(meta_path, encoding="utf-8") as f:
meta = json.load(f)
backend_name = meta.get("backend_name")
if backend_name != self.backend_name:
raise ValueError(
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
)
meta_backend_kwargs = meta.get("backend_kwargs", {})
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
if index_is_compact:
raise ValueError(
"Compact HNSW indices do not support in-place updates. Rebuild required."
)
distance_metric = meta_backend_kwargs.get(
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
).lower()
needs_recompute = bool(
meta.get("is_pruned")
or meta_backend_kwargs.get("is_recompute")
or self.backend_kwargs.get("is_recompute")
)
with open(offset_file, "rb") as f:
offset_map: dict[str, int] = pickle.load(f)
existing_ids = set(offset_map.keys())
valid_chunks: list[dict[str, Any]] = []
for chunk in self.chunks:
text = chunk.get("text", "")
if not isinstance(text, str) or not text.strip():
continue
metadata = chunk.setdefault("metadata", {})
passage_id = chunk.get("id") or metadata.get("id")
if passage_id and passage_id in existing_ids:
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
valid_chunks.append(chunk)
if not valid_chunks:
raise ValueError("No valid chunks to append.")
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
embeddings = compute_embeddings(
texts_to_embed,
self.embedding_model,
self.embedding_mode,
use_server=False,
is_build=True,
provider_options=self.embedding_options,
)
embedding_dim = embeddings.shape[1]
expected_dim = meta.get("dimensions")
if expected_dim is not None and expected_dim != embedding_dim:
raise ValueError(
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
)
from leann_backend_hnsw import faiss # type: ignore
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
if distance_metric == "cosine":
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1
embeddings = embeddings / norms
index = faiss.read_index(str(index_file))
if hasattr(index, "is_recompute"):
index.is_recompute = needs_recompute
if getattr(index, "storage", None) is None:
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
storage_index = faiss.IndexFlatIP(index.d)
else:
storage_index = faiss.IndexFlatL2(index.d)
index.storage = storage_index
index.own_fields = True
if index.d != embedding_dim:
raise ValueError(
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
)
base_id = index.ntotal
for offset, chunk in enumerate(valid_chunks):
new_id = str(base_id + offset)
chunk.setdefault("metadata", {})["id"] = new_id
chunk["id"] = new_id
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
faiss.write_index(index, str(index_file))
with open(passages_file, "a", encoding="utf-8") as f:
for chunk in valid_chunks:
offset = f.tell()
json.dump(
{
"id": chunk["id"],
"text": chunk["text"],
"metadata": chunk.get("metadata", {}),
},
f,
ensure_ascii=False,
)
f.write("\n")
offset_map[chunk["id"]] = offset
with open(offset_file, "wb") as f:
pickle.dump(offset_map, f)
meta["total_passages"] = len(offset_map)
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f, indent=2)
logger.info(
"Appended %d passages to index '%s'. New total: %d",
len(valid_chunks),
index_path,
len(offset_map),
)
self.chunks.clear()
if needs_recompute:
prune_hnsw_embeddings_inplace(str(index_file))
class LeannSearcher:
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
@@ -574,15 +784,20 @@ class LeannSearcher:
self.embedding_model = self.meta_data["embedding_model"]
# Support both old and new format
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta_data.get("embedding_options", {})
# Delegate portability handling to PassageManager
self.passage_manager = PassageManager(
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
)
# Preserve backend name for conditional parameter forwarding
self.backend_name = backend_name
backend_factory = BACKEND_REGISTRY.get(backend_name)
if backend_factory is None:
raise ValueError(f"Backend '{backend_name}' not found.")
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
final_kwargs["enable_warmup"] = enable_warmup
if self.embedding_options:
final_kwargs.setdefault("embedding_options", self.embedding_options)
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
index_path, **final_kwargs
)
@@ -597,11 +812,43 @@ class LeannSearcher:
recompute_embeddings: bool = True,
pruning_strategy: Literal["global", "local", "proportional"] = "global",
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**kwargs,
) -> list[SearchResult]:
"""
Search for nearest neighbors with optional metadata filtering.
Args:
query: Text query to search for
top_k: Number of nearest neighbors to return
complexity: Search complexity/candidate list size, higher = more accurate but slower
beam_width: Number of parallel search paths/IO requests per iteration
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
expected_zmq_port: ZMQ port for embedding server communication
metadata_filters: Optional filters to apply to search results based on metadata.
Format: {"field_name": {"operator": value}}
Supported operators:
- Comparison: "==", "!=", "<", "<=", ">", ">="
- Membership: "in", "not_in"
- String: "contains", "starts_with", "ends_with"
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
**kwargs: Backend-specific parameters
Returns:
List of SearchResult objects with text, metadata, and similarity scores
"""
# Handle grep search
if use_grep:
return self._grep_search(query, top_k)
logger.info("🔍 LeannSearcher.search() called:")
logger.info(f" Query: '{query}'")
logger.info(f" Top_k: {top_k}")
logger.info(f" Metadata filters: {metadata_filters}")
logger.info(f" Additional kwargs: {kwargs}")
# Smart top_k detection and adjustment
@@ -636,23 +883,33 @@ class LeannSearcher:
use_server_if_available=recompute_embeddings,
zmq_port=zmq_port,
)
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
# time.time() - start_time
# logger.info(f" Embedding time: {embedding_time} seconds")
logger.info(f" Generated embedding shape: {query_embedding.shape}")
embedding_time = time.time() - start_time
logger.info(f" Embedding time: {embedding_time} seconds")
start_time = time.time()
backend_search_kwargs: dict[str, Any] = {
"complexity": complexity,
"beam_width": beam_width,
"prune_ratio": prune_ratio,
"recompute_embeddings": recompute_embeddings,
"pruning_strategy": pruning_strategy,
"zmq_port": zmq_port,
}
# Only HNSW supports batching; forward conditionally
if self.backend_name == "hnsw":
backend_search_kwargs["batch_size"] = batch_size
# Merge any extra kwargs last
backend_search_kwargs.update(kwargs)
results = self.backend_impl.search(
query_embedding,
top_k,
complexity=complexity,
beam_width=beam_width,
prune_ratio=prune_ratio,
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
zmq_port=zmq_port,
**kwargs,
**backend_search_kwargs,
)
# logger.info(f" Search time: {search_time} seconds")
search_time = time.time() - start_time
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
enriched_results = []
@@ -691,15 +948,109 @@ class LeannSearcher:
f" {RED}{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
)
# Apply metadata filters if specified
if metadata_filters:
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
enriched_results = self.passage_manager.filter_search_results(
enriched_results, metadata_filters
)
# Define color codes outside the loop for final message
GREEN = "\033[92m"
RESET = "\033[0m"
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
return enriched_results
def _find_jsonl_file(self) -> Optional[str]:
"""Find the .jsonl file containing raw passages for grep search"""
index_path = Path(self.meta_path_str).parent
potential_files = [
index_path / "documents.leann.passages.jsonl",
index_path.parent / "documents.leann.passages.jsonl",
]
for file_path in potential_files:
if file_path.exists():
return str(file_path)
return None
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Perform grep-based search on raw passages"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl passages file found for grep search")
try:
cmd = ["grep", "-i", "-n", query, jsonl_file]
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 1:
return []
elif result.returncode != 0:
raise RuntimeError(f"Grep failed: {result.stderr}")
matches = []
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split(":", 1)
if len(parts) != 2:
continue
try:
data = json.loads(parts[1])
text = data.get("text", "")
score = text.lower().count(query.lower())
matches.append(
SearchResult(
id=data.get("id", parts[0]),
text=text,
metadata=data.get("metadata", {}),
score=float(score),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
except FileNotFoundError:
raise RuntimeError(
"grep command not found. Please install grep or use semantic search."
)
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
"""Fallback regex search"""
jsonl_file = self._find_jsonl_file()
if not jsonl_file:
raise FileNotFoundError("No .jsonl file found")
pattern = re.compile(re.escape(query), re.IGNORECASE)
matches = []
with open(jsonl_file, encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
if pattern.search(line):
try:
data = json.loads(line.strip())
matches.append(
SearchResult(
id=data.get("id", str(line_num)),
text=data.get("text", ""),
metadata=data.get("metadata", {}),
score=float(len(pattern.findall(data.get("text", "")))),
)
)
except json.JSONDecodeError:
continue
matches.sort(key=lambda x: x.score, reverse=True)
return matches[:top_k]
def cleanup(self):
"""Explicitly cleanup embedding server resources.
This method should be called after you're done using the searcher,
especially in test environments or batch processing scenarios.
"""
@@ -731,9 +1082,15 @@ class LeannChat:
index_path: str,
llm_config: Optional[dict[str, Any]] = None,
enable_warmup: bool = False,
searcher: Optional[LeannSearcher] = None,
**kwargs,
):
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
if searcher is None:
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
self._owns_searcher = True
else:
self.searcher = searcher
self._owns_searcher = False
self.llm = get_llm(llm_config)
def ask(
@@ -747,6 +1104,9 @@ class LeannChat:
pruning_strategy: Literal["global", "local", "proportional"] = "global",
llm_kwargs: Optional[dict[str, Any]] = None,
expected_zmq_port: int = 5557,
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
batch_size: int = 0,
use_grep: bool = False,
**search_kwargs,
):
if llm_kwargs is None:
@@ -761,10 +1121,12 @@ class LeannChat:
recompute_embeddings=recompute_embeddings,
pruning_strategy=pruning_strategy,
expected_zmq_port=expected_zmq_port,
metadata_filters=metadata_filters,
batch_size=batch_size,
**search_kwargs,
)
search_time = time.time() - search_time
# logger.info(f" Search time: {search_time} seconds")
logger.info(f" Search time: {search_time} seconds")
context = "\n\n".join([r.text for r in results])
prompt = (
"Here is some retrieved context that might help answer your question:\n\n"
@@ -800,7 +1162,9 @@ class LeannChat:
This method should be called after you're done using the chat interface,
especially in test environments or batch processing scenarios.
"""
if hasattr(self.searcher, "cleanup"):
# Only stop the embedding server if this LeannChat instance created the searcher.
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
self.searcher.cleanup()
# Enable automatic cleanup patterns

View File

@@ -12,6 +12,8 @@ from typing import Any, Optional
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
def validate_model_and_suggest(
model_name: str, llm_type: str, host: str = "http://localhost:11434"
model_name: str, llm_type: str, host: Optional[str] = None
) -> Optional[str]:
"""Validate model name and provide suggestions if invalid"""
if llm_type == "ollama":
available_models = check_ollama_models(host)
resolved_host = resolve_ollama_host(host)
available_models = check_ollama_models(resolved_host)
if available_models and model_name not in available_models:
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
class OllamaChat(LLMInterface):
"""LLM interface for Ollama models."""
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
self.model = model
self.host = host
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
self.host = resolve_ollama_host(host)
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
try:
import requests
# Check if the Ollama server is responsive
if host:
requests.get(host)
if self.host:
requests.get(self.host)
# Pre-check model availability with helpful suggestions
model_error = validate_model_and_suggest(model, "ollama", host)
model_error = validate_model_and_suggest(model, "ollama", self.host)
if model_error:
raise ValueError(model_error)
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
)
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
logger.error(
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
raise ConnectionError(
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
)
def ask(self, prompt: str, **kwargs) -> str:
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
class OpenAIChat(LLMInterface):
"""LLM interface for OpenAI models."""
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
def __init__(
self,
model: str = "gpt-4o",
api_key: Optional[str] = None,
base_url: Optional[str] = None,
):
self.model = model
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = resolve_openai_base_url(base_url)
self.api_key = resolve_openai_api_key(api_key)
if not self.api_key:
raise ValueError(
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
)
logger.info(f"Initializing OpenAI Chat with model='{model}'")
logger.info(
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
model,
self.base_url,
)
try:
import openai
self.client = openai.OpenAI(api_key=self.api_key)
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
except ImportError:
raise ImportError(
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
if llm_type == "ollama":
return OllamaChat(
model=model or "llama3:8b",
host=llm_config.get("host", "http://localhost:11434"),
host=llm_config.get("host"),
)
elif llm_type == "hf":
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
elif llm_type == "openai":
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
return OpenAIChat(
model=model or "gpt-4o",
api_key=llm_config.get("api_key"),
base_url=llm_config.get("base_url"),
)
elif llm_type == "gemini":
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
elif llm_type == "simulated":

View File

@@ -1,6 +1,6 @@
"""
Enhanced chunking utilities with AST-aware code chunking support.
Provides unified interface for both traditional and AST-based text chunking.
Packaged within leann-core so installed wheels can import it reliably.
"""
import logging
@@ -22,30 +22,9 @@ CODE_EXTENSIONS = {
".jsx": "typescript",
}
# Default chunk parameters for different content types
DEFAULT_CHUNK_PARAMS = {
"code": {
"max_chunk_size": 512,
"chunk_overlap": 64,
},
"text": {
"chunk_size": 256,
"chunk_overlap": 128,
},
}
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
"""
Separate documents into code files and regular text files.
Args:
documents: List of LlamaIndex Document objects
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
Returns:
Tuple of (code_documents, text_documents)
"""
"""Separate documents into code files and regular text files."""
if code_extensions is None:
code_extensions = CODE_EXTENSIONS
@@ -53,16 +32,10 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
text_docs = []
for doc in documents:
# Get file path from metadata
file_path = doc.metadata.get("file_path", "")
if not file_path:
# Fallback to file_name
file_path = doc.metadata.get("file_name", "")
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
if file_path:
file_ext = Path(file_path).suffix.lower()
if file_ext in code_extensions:
# Add language info to metadata
doc.metadata["language"] = code_extensions[file_ext]
doc.metadata["is_code"] = True
code_docs.append(doc)
@@ -70,7 +43,6 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
doc.metadata["is_code"] = False
text_docs.append(doc)
else:
# If no file path, treat as text
doc.metadata["is_code"] = False
text_docs.append(doc)
@@ -79,7 +51,7 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
def get_language_from_extension(file_path: str) -> Optional[str]:
"""Get the programming language from file extension."""
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
ext = Path(file_path).suffix.lower()
return CODE_EXTENSIONS.get(ext)
@@ -90,40 +62,26 @@ def create_ast_chunks(
chunk_overlap: int = 64,
metadata_template: str = "default",
) -> list[str]:
"""
Create AST-aware chunks from code documents using astchunk.
"""Create AST-aware chunks from code documents using astchunk.
Args:
documents: List of code documents
max_chunk_size: Maximum characters per chunk
chunk_overlap: Number of AST nodes to overlap between chunks
metadata_template: Template for chunk metadata
Returns:
List of text chunks with preserved code structure
Falls back to traditional chunking if astchunk is unavailable.
"""
try:
from astchunk import ASTChunkBuilder
from astchunk import ASTChunkBuilder # optional dependency
except ImportError as e:
logger.error(f"astchunk not available: {e}")
logger.info("Falling back to traditional chunking for code files")
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
all_chunks = []
for doc in documents:
# Get language from metadata (set by detect_code_files)
language = doc.metadata.get("language")
if not language:
logger.warning(
"No language detected for document, falling back to traditional chunking"
)
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
logger.warning("No language detected; falling back to traditional chunking")
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
continue
try:
# Configure astchunk
configs = {
"max_chunk_size": max_chunk_size,
"language": language,
@@ -131,7 +89,6 @@ def create_ast_chunks(
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
}
# Add repository-level metadata if available
repo_metadata = {
"file_path": doc.metadata.get("file_path", ""),
"file_name": doc.metadata.get("file_name", ""),
@@ -140,17 +97,13 @@ def create_ast_chunks(
}
configs["repo_level_metadata"] = repo_metadata
# Create chunk builder and process
chunk_builder = ASTChunkBuilder(**configs)
code_content = doc.get_content()
if not code_content or not code_content.strip():
logger.warning("Empty code content, skipping")
continue
chunks = chunk_builder.chunkify(code_content)
# Extract text content from chunks
for chunk in chunks:
if hasattr(chunk, "text"):
chunk_text = chunk.text
@@ -159,7 +112,6 @@ def create_ast_chunks(
elif isinstance(chunk, str):
chunk_text = chunk
else:
# Try to convert to string
chunk_text = str(chunk)
if chunk_text and chunk_text.strip():
@@ -168,12 +120,10 @@ def create_ast_chunks(
logger.info(
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
)
except Exception as e:
logger.warning(f"AST chunking failed for {language} file: {e}")
logger.info("Falling back to traditional chunking")
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
all_chunks.extend(traditional_chunks)
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
return all_chunks
@@ -181,23 +131,10 @@ def create_ast_chunks(
def create_traditional_chunks(
documents, chunk_size: int = 256, chunk_overlap: int = 128
) -> list[str]:
"""
Create traditional text chunks using LlamaIndex SentenceSplitter.
Args:
documents: List of documents to chunk
chunk_size: Size of each chunk in characters
chunk_overlap: Overlap between chunks
Returns:
List of text chunks
"""
# Handle invalid chunk_size values
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
if chunk_size <= 0:
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
chunk_size = 256
# Ensure chunk_overlap is not negative and not larger than chunk_size
if chunk_overlap < 0:
chunk_overlap = 0
if chunk_overlap >= chunk_size:
@@ -215,12 +152,9 @@ def create_traditional_chunks(
try:
nodes = node_parser.get_nodes_from_documents([doc])
if nodes:
chunk_texts = [node.get_content() for node in nodes]
all_texts.extend(chunk_texts)
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
all_texts.extend(node.get_content() for node in nodes)
except Exception as e:
logger.error(f"Traditional chunking failed for document: {e}")
# As last resort, add the raw content
content = doc.get_content()
if content and content.strip():
all_texts.append(content.strip())
@@ -238,32 +172,13 @@ def create_text_chunks(
code_file_extensions: Optional[list[str]] = None,
ast_fallback_traditional: bool = True,
) -> list[str]:
"""
Create text chunks from documents with optional AST support for code files.
Args:
documents: List of LlamaIndex Document objects
chunk_size: Size for traditional text chunks
chunk_overlap: Overlap for traditional text chunks
use_ast_chunking: Whether to use AST chunking for code files
ast_chunk_size: Size for AST chunks
ast_chunk_overlap: Overlap for AST chunks
code_file_extensions: Custom list of code file extensions
ast_fallback_traditional: Fall back to traditional chunking on AST errors
Returns:
List of text chunks
"""
"""Create text chunks from documents with optional AST support for code files."""
if not documents:
logger.warning("No documents provided for chunking")
return []
# Create a local copy of supported extensions for this function call
local_code_extensions = CODE_EXTENSIONS.copy()
# Update supported extensions if provided
if code_file_extensions:
# Map extensions to languages (simplified mapping)
ext_mapping = {
".py": "python",
".java": "java",
@@ -273,47 +188,32 @@ def create_text_chunks(
}
for ext in code_file_extensions:
if ext.lower() not in local_code_extensions:
# Try to guess language from extension
if ext.lower() in ext_mapping:
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
else:
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
all_chunks = []
if use_ast_chunking:
# Separate code and text documents using local extensions
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
# Process code files with AST chunking
if code_docs:
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
try:
ast_chunks = create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
all_chunks.extend(
create_ast_chunks(
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
)
)
all_chunks.extend(ast_chunks)
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
except Exception as e:
logger.error(f"AST chunking failed: {e}")
if ast_fallback_traditional:
logger.info("Falling back to traditional chunking for code files")
traditional_code_chunks = create_traditional_chunks(
code_docs, chunk_size, chunk_overlap
all_chunks.extend(
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
)
all_chunks.extend(traditional_code_chunks)
else:
raise
# Process text files with traditional chunking
if text_docs:
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
all_chunks.extend(text_chunks)
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
else:
# Use traditional chunking for all files
logger.info(f"Processing {len(documents)} documents with traditional chunking")
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
logger.info(f"Total chunks created: {len(all_chunks)}")

View File

@@ -1,6 +1,5 @@
import argparse
import asyncio
import sys
from pathlib import Path
from typing import Any, Optional, Union
@@ -10,6 +9,7 @@ from tqdm import tqdm
from .api import LeannBuilder, LeannChat, LeannSearcher
from .registry import register_project_directory
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
@@ -124,6 +124,24 @@ Examples:
choices=["sentence-transformers", "openai", "mlx", "ollama"],
help="Embedding backend mode (default: sentence-transformers)",
)
build_parser.add_argument(
"--embedding-host",
type=str,
default=None,
help="Override Ollama-compatible embedding host",
)
build_parser.add_argument(
"--embedding-api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible embedding services",
)
build_parser.add_argument(
"--embedding-api-key",
type=str,
default=None,
help="API key for embedding service (defaults to OPENAI_API_KEY)",
)
build_parser.add_argument(
"--force", "-f", action="store_true", help="Force rebuild existing index"
)
@@ -239,6 +257,11 @@ Examples:
# Ask command
ask_parser = subparsers.add_parser("ask", help="Ask questions")
ask_parser.add_argument("index_name", help="Index name")
ask_parser.add_argument(
"query",
nargs="?",
help="Question to ask (omit for prompt or when using --interactive)",
)
ask_parser.add_argument(
"--llm",
type=str,
@@ -249,7 +272,12 @@ Examples:
ask_parser.add_argument(
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
)
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
ask_parser.add_argument(
"--host",
type=str,
default=None,
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
)
ask_parser.add_argument(
"--interactive", "-i", action="store_true", help="Interactive chat mode"
)
@@ -278,6 +306,18 @@ Examples:
default=None,
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
)
ask_parser.add_argument(
"--api-base",
type=str,
default=None,
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
)
ask_parser.add_argument(
"--api-key",
type=str,
default=None,
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
)
# List command
subparsers.add_parser("list", help="List all indexes")
@@ -322,9 +362,17 @@ Examples:
return basic_matches
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
"""Check if a file should be excluded using gitignore parser."""
return gitignore_matches(str(relative_path))
def _should_exclude_file(self, file_path: Path, gitignore_matches) -> bool:
"""Check if a file should be excluded using gitignore parser.
Always match against absolute, posix-style paths for consistency with
gitignore_parser expectations.
"""
try:
absolute_path = file_path.resolve()
except Exception:
absolute_path = Path(str(file_path))
return gitignore_matches(absolute_path.as_posix())
def _is_git_submodule(self, path: Path) -> bool:
"""Check if a path is a git submodule."""
@@ -396,7 +444,9 @@ Examples:
print(f" {current_path}")
print(" " + "" * 45)
current_indexes = self._discover_indexes_in_project(current_path)
current_indexes = self._discover_indexes_in_project(
current_path, exclude_dirs=other_projects
)
if current_indexes:
for idx in current_indexes:
total_indexes += 1
@@ -435,9 +485,14 @@ Examples:
print(" leann build my-docs --docs ./documents")
else:
# Count only projects that have at least one discoverable index
projects_count = sum(
1 for p in valid_projects if len(self._discover_indexes_in_project(p)) > 0
)
projects_count = 0
for p in valid_projects:
if p == current_path:
discovered = self._discover_indexes_in_project(p, exclude_dirs=other_projects)
else:
discovered = self._discover_indexes_in_project(p)
if len(discovered) > 0:
projects_count += 1
print(f"📊 Total: {total_indexes} indexes across {projects_count} projects")
if current_indexes_count > 0:
@@ -454,9 +509,22 @@ Examples:
print("\n💡 Create your first index:")
print(" leann build my-docs --docs ./documents")
def _discover_indexes_in_project(self, project_path: Path):
"""Discover all indexes in a project directory (both CLI and apps formats)"""
def _discover_indexes_in_project(
self, project_path: Path, exclude_dirs: Optional[list[Path]] = None
):
"""Discover all indexes in a project directory (both CLI and apps formats)
exclude_dirs: when provided, skip any APP-format index files that are
located under these directories. This prevents duplicates when the
current project is a parent directory of other registered projects.
"""
indexes = []
exclude_dirs = exclude_dirs or []
# normalize to resolved paths once for comparison
try:
exclude_dirs_resolved = [p.resolve() for p in exclude_dirs]
except Exception:
exclude_dirs_resolved = exclude_dirs
# 1. CLI format: .leann/indexes/index_name/
cli_indexes_dir = project_path / ".leann" / "indexes"
@@ -495,6 +563,17 @@ Examples:
continue
except Exception:
pass
# Skip meta files that live under excluded directories
try:
meta_parent_resolved = meta_file.parent.resolve()
if any(
meta_parent_resolved.is_relative_to(ex_dir)
for ex_dir in exclude_dirs_resolved
):
continue
except Exception:
# best effort; if resolve or comparison fails, do not exclude
pass
# Use the parent directory name as the app index display name
display_name = meta_file.parent.name
# Extract file base used to store files
@@ -1022,7 +1101,8 @@ Examples:
# Try to use better PDF parsers first, but only if PDFs are requested
documents = []
docs_path = Path(docs_dir)
# Use resolved absolute paths to avoid mismatches (symlinks, relative vs absolute)
docs_path = Path(docs_dir).resolve()
# Check if we should process PDFs
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
@@ -1031,10 +1111,15 @@ Examples:
for file_path in docs_path.rglob("*.pdf"):
# Check if file matches any exclude pattern
try:
# Ensure both paths are resolved before computing relativity
file_path_resolved = file_path.resolve()
# Determine directory scope using the non-resolved path to avoid
# misclassifying symlinked entries as outside the docs directory
relative_path = file_path.relative_to(docs_path)
if not include_hidden and _path_has_hidden_segment(relative_path):
continue
if self._should_exclude_file(relative_path, gitignore_matches):
# Use absolute path for gitignore matching
if self._should_exclude_file(file_path_resolved, gitignore_matches):
continue
except ValueError:
# Skip files that can't be made relative to docs_path
@@ -1077,10 +1162,11 @@ Examples:
) -> bool:
"""Return True if file should be included (not excluded)"""
try:
docs_path_obj = Path(docs_dir)
file_path_obj = Path(file_path)
relative_path = file_path_obj.relative_to(docs_path_obj)
return not self._should_exclude_file(relative_path, gitignore_matches)
docs_path_obj = Path(docs_dir).resolve()
file_path_obj = Path(file_path).resolve()
# Use absolute path for gitignore matching
_ = file_path_obj.relative_to(docs_path_obj) # validate scope
return not self._should_exclude_file(file_path_obj, gitignore_matches)
except (ValueError, OSError):
return True # Include files that can't be processed
@@ -1170,13 +1256,8 @@ Examples:
if use_ast:
print("🧠 Using AST-aware chunking for code files")
try:
# Import enhanced chunking utilities
# Add apps directory to path to import chunking utilities
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
if apps_dir.exists():
sys.path.insert(0, str(apps_dir))
from chunking import create_text_chunks
# Import enhanced chunking utilities from packaged module
from .chunking_utils import create_text_chunks
# Use enhanced chunking with AST support
all_texts = create_text_chunks(
@@ -1191,7 +1272,9 @@ Examples:
)
except ImportError as e:
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking")
print(
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
)
use_ast = False
if not use_ast:
@@ -1283,10 +1366,20 @@ Examples:
print(f"Building index '{index_name}' with {args.backend} backend...")
embedding_options: dict[str, Any] = {}
if args.embedding_mode == "ollama":
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
elif args.embedding_mode == "openai":
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
if resolved_embedding_key:
embedding_options["api_key"] = resolved_embedding_key
builder = LeannBuilder(
backend_name=args.backend,
embedding_model=args.embedding_model,
embedding_mode=args.embedding_mode,
embedding_options=embedding_options or None,
graph_degree=args.graph_degree,
complexity=args.complexity,
is_compact=args.compact,
@@ -1434,11 +1527,38 @@ Examples:
llm_config = {"type": args.llm, "model": args.model}
if args.llm == "ollama":
llm_config["host"] = args.host
llm_config["host"] = resolve_ollama_host(args.host)
elif args.llm == "openai":
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
resolved_api_key = resolve_openai_api_key(args.api_key)
if resolved_api_key:
llm_config["api_key"] = resolved_api_key
chat = LeannChat(index_path=index_path, llm_config=llm_config)
llm_kwargs: dict[str, Any] = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
def _ask_once(prompt: str) -> None:
response = chat.ask(
prompt,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
initial_query = (args.query or "").strip()
if args.interactive:
if initial_query:
_ask_once(initial_query)
print("LEANN Assistant ready! Type 'quit' to exit")
print("=" * 40)
@@ -1451,41 +1571,14 @@ Examples:
if not user_input:
continue
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
response = chat.ask(
user_input,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
_ask_once(user_input)
else:
query = input("Enter your question: ").strip()
if query:
# Prepare LLM kwargs with thinking budget if specified
llm_kwargs = {}
if args.thinking_budget:
llm_kwargs["thinking_budget"] = args.thinking_budget
query = initial_query or input("Enter your question: ").strip()
if not query:
print("No question provided. Exiting.")
return
response = chat.ask(
query,
top_k=args.top_k,
complexity=args.complexity,
beam_width=args.beam_width,
prune_ratio=args.prune_ratio,
recompute_embeddings=args.recompute_embeddings,
pruning_strategy=args.pruning_strategy,
llm_kwargs=llm_kwargs,
)
print(f"LEANN: {response}")
_ask_once(query)
async def run(self, args=None):
parser = self.create_parser()

View File

@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
import logging
import os
from typing import Any
import time
from typing import Any, Optional
import numpy as np
import torch
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
# Set up logger with proper level
logger = logging.getLogger(__name__)
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
@@ -28,6 +31,9 @@ def compute_embeddings(
is_build: bool = False,
batch_size: int = 32,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
provider_options: Optional[dict[str, Any]] = None,
) -> np.ndarray:
"""
Unified embedding computation entry point
@@ -43,6 +49,8 @@ def compute_embeddings(
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
"""
provider_options = provider_options or {}
if mode == "sentence-transformers":
return compute_embeddings_sentence_transformers(
texts,
@@ -50,13 +58,25 @@ def compute_embeddings(
is_build=is_build,
batch_size=batch_size,
adaptive_optimization=adaptive_optimization,
manual_tokenize=manual_tokenize,
max_length=max_length,
)
elif mode == "openai":
return compute_embeddings_openai(texts, model_name)
return compute_embeddings_openai(
texts,
model_name,
base_url=provider_options.get("base_url"),
api_key=provider_options.get("api_key"),
)
elif mode == "mlx":
return compute_embeddings_mlx(texts, model_name)
elif mode == "ollama":
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
return compute_embeddings_ollama(
texts,
model_name,
is_build=is_build,
host=provider_options.get("host"),
)
elif mode == "gemini":
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
else:
@@ -71,6 +91,8 @@ def compute_embeddings_sentence_transformers(
batch_size: int = 32,
is_build: bool = False,
adaptive_optimization: bool = True,
manual_tokenize: bool = False,
max_length: int = 512,
) -> np.ndarray:
"""
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
@@ -214,20 +236,130 @@ def compute_embeddings_sentence_transformers(
logger.info(f"Model cached: {cache_key}")
# Compute embeddings with optimized inference mode
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
logger.info(
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
)
# Use torch.inference_mode for optimal performance
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
start_time = time.time()
if not manual_tokenize:
# Use SentenceTransformer's optimized encode path (default)
with torch.inference_mode():
embeddings = model.encode(
texts,
batch_size=batch_size,
show_progress_bar=is_build, # Don't show progress bar in server environment
convert_to_numpy=True,
normalize_embeddings=False,
device=device,
)
# Synchronize if CUDA to measure accurate wall time
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
else:
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
try:
from transformers import AutoModel, AutoTokenizer # type: ignore
except Exception as e:
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
# Cache tokenizer and model
tok_cache_key = f"hf_tokenizer_{model_name}"
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
hf_tokenizer = _model_cache[tok_cache_key]
hf_model = _model_cache[mdl_cache_key]
logger.info("Using cached HF tokenizer/model for manual path")
else:
logger.info("Loading HF tokenizer/model for manual tokenization path")
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
hf_model.to(device)
hf_model.eval()
# Optional compile on supported devices
if device in ["cuda", "mps"]:
try:
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
except Exception:
pass
_model_cache[tok_cache_key] = hf_tokenizer
_model_cache[mdl_cache_key] = hf_model
all_embeddings: list[np.ndarray] = []
# Progress bar when building or for large inputs
show_progress = is_build or len(texts) > 32
try:
if show_progress:
from tqdm import tqdm # type: ignore
batch_iter = tqdm(
range(0, len(texts), batch_size),
desc="Embedding (manual)",
unit="batch",
)
else:
batch_iter = range(0, len(texts), batch_size)
except Exception:
batch_iter = range(0, len(texts), batch_size)
start_time_manual = time.time()
with torch.inference_mode():
for start_index in batch_iter:
end_index = min(start_index + batch_size, len(texts))
batch_texts = texts[start_index:end_index]
tokenize_start_time = time.time()
inputs = hf_tokenizer(
batch_texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
tokenize_end_time = time.time()
logger.info(
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
)
# Print shapes of all input tensors for debugging
for k, v in inputs.items():
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
to_device_start_time = time.time()
inputs = {k: v.to(device) for k, v in inputs.items()}
to_device_end_time = time.time()
logger.info(
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
)
forward_start_time = time.time()
outputs = hf_model(**inputs)
forward_end_time = time.time()
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
last_hidden_state = outputs.last_hidden_state # (B, L, H)
attention_mask = inputs.get("attention_mask")
if attention_mask is None:
# Fallback: assume all tokens are valid
pooled = last_hidden_state.mean(dim=1)
else:
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
masked = last_hidden_state * mask
lengths = mask.sum(dim=1).clamp(min=1)
pooled = masked.sum(dim=1) / lengths
# Move to CPU float32
batch_embeddings = pooled.detach().to("cpu").float().numpy()
all_embeddings.append(batch_embeddings)
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
except Exception:
pass
end_time = time.time()
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
end_time = time.time()
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
logger.info(f"Time taken: {end_time - start_time} seconds")
# Validate results
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
@@ -236,12 +368,15 @@ def compute_embeddings_sentence_transformers(
return embeddings
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
def compute_embeddings_openai(
texts: list[str],
model_name: str,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
) -> np.ndarray:
# TODO: @yichuan-w add progress bar only in build mode
"""Compute embeddings using OpenAI API"""
try:
import os
import openai
except ImportError as e:
raise ImportError(f"OpenAI package not installed: {e}")
@@ -256,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
)
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
resolved_base_url = resolve_openai_base_url(base_url)
resolved_api_key = resolve_openai_api_key(api_key)
if not resolved_api_key:
raise RuntimeError("OPENAI_API_KEY environment variable not set")
# Cache OpenAI client
cache_key = "openai_client"
cache_key = f"openai_client::{resolved_base_url}"
if cache_key in _model_cache:
client = _model_cache[cache_key]
else:
client = openai.OpenAI(api_key=api_key)
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
_model_cache[cache_key] = client
logger.info("OpenAI client cached")
@@ -390,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
def compute_embeddings_ollama(
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
texts: list[str],
model_name: str,
is_build: bool = False,
host: Optional[str] = None,
) -> np.ndarray:
"""
Compute embeddings using Ollama API with simplified batch processing.
@@ -401,7 +541,7 @@ def compute_embeddings_ollama(
texts: List of texts to compute embeddings for
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
is_build: Whether this is a build operation (shows progress bar)
host: Ollama host URL (default: http://localhost:11434)
host: Ollama host URL (defaults to environment or http://localhost:11434)
Returns:
Normalized embeddings array, shape: (len(texts), embedding_dim)
@@ -416,17 +556,19 @@ def compute_embeddings_ollama(
if not texts:
raise ValueError("Cannot compute embeddings for empty text list")
resolved_host = resolve_ollama_host(host)
logger.info(
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
)
# Check if Ollama is running
try:
response = requests.get(f"{host}/api/version", timeout=5)
response = requests.get(f"{resolved_host}/api/version", timeout=5)
response.raise_for_status()
except requests.exceptions.ConnectionError:
error_msg = (
f"❌ Could not connect to Ollama at {host}.\n\n"
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
"Please ensure Ollama is running:\n"
" • macOS/Linux: ollama serve\n"
" • Windows: Make sure Ollama is running in the system tray\n\n"
@@ -438,7 +580,7 @@ def compute_embeddings_ollama(
# Check if model exists and provide helpful suggestions
try:
response = requests.get(f"{host}/api/tags", timeout=5)
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
response.raise_for_status()
models = response.json()
model_names = [model["name"] for model in models.get("models", [])]
@@ -501,7 +643,9 @@ def compute_embeddings_ollama(
# Verify the model supports embeddings by testing it
try:
test_response = requests.post(
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": "test"},
timeout=10,
)
if test_response.status_code != 200:
error_msg = (
@@ -548,7 +692,7 @@ def compute_embeddings_ollama(
while retry_count < max_retries:
try:
response = requests.post(
f"{host}/api/embeddings",
f"{resolved_host}/api/embeddings",
json={"model": model_name, "prompt": truncated_text},
timeout=30,
)

View File

@@ -8,6 +8,8 @@ import time
from pathlib import Path
from typing import Optional
from .settings import encode_provider_options
# Lightweight, self-contained server manager with no cross-process inspection
# Set up logging based on environment variable
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
) -> tuple[bool, int]:
"""Start the embedding server."""
# passages_file may be present in kwargs for server CLI, but we don't need it here
provider_options = kwargs.pop("provider_options", None)
config_signature = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
# If this manager already has a live server, just reuse it
if self.server_process and self.server_process.poll() is None and self.server_port:
if (
self.server_process
and self.server_process.poll() is None
and self.server_port
and self._server_config == config_signature
):
logger.info("Reusing in-process server")
return True, self.server_port
# Configuration changed, stop existing server before starting a new one
if self.server_process and self.server_process.poll() is None:
logger.info("Existing server configuration differs; restarting embedding server")
self.stop_server()
# For Colab environment, use a different strategy
if _is_colab_environment():
logger.info("Detected Colab environment, using alternative startup strategy")
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
return self._start_server_colab(
port,
model_name,
embedding_mode,
provider_options=provider_options,
**kwargs,
)
# Always pick a fresh available port
try:
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
return False, port
# Start a new server
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
return self._start_new_server(
actual_port,
model_name,
embedding_mode,
provider_options=provider_options,
config_signature=config_signature,
**kwargs,
)
def _start_server_colab(
self,
port: int,
model_name: str,
embedding_mode: str = "sentence-transformers",
provider_options: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start server with Colab-specific configuration."""
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
try:
# In Colab, we'll use a more direct approach
self._launch_server_process_colab(command, actual_port)
return self._wait_for_server_ready_colab(actual_port)
self._launch_server_process_colab(
command,
actual_port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready_colab(actual_port)
if started:
self._server_config = {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server in Colab: {e}")
return False, actual_port
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
def _start_new_server(
self, port: int, model_name: str, embedding_mode: str, **kwargs
self,
port: int,
model_name: str,
embedding_mode: str,
provider_options: Optional[dict] = None,
config_signature: Optional[dict] = None,
**kwargs,
) -> tuple[bool, int]:
"""Start a new embedding server on the given port."""
logger.info(f"Starting embedding server on port {port}...")
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
try:
self._launch_server_process(command, port)
return self._wait_for_server_ready(port)
self._launch_server_process(
command,
port,
provider_options=provider_options,
)
started, ready_port = self._wait_for_server_ready(port)
if started:
self._server_config = config_signature or {
"model_name": model_name,
"passages_file": kwargs.get("passages_file", ""),
"embedding_mode": embedding_mode,
"provider_options": provider_options or {},
}
return started, ready_port
except Exception as e:
logger.error(f"Failed to start embedding server: {e}")
return False, port
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
return command
def _launch_server_process(self, command: list, port: int) -> None:
def _launch_server_process(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process."""
project_root = Path(__file__).parent.parent.parent.parent.parent
logger.info(f"Command: {' '.join(command)}")
@@ -193,14 +262,20 @@ class EmbeddingServerManager:
# Start embedding server subprocess
logger.info(f"Starting server process with command: {' '.join(command)}")
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
cwd=project_root,
stdout=stdout_target,
stderr=stderr_target,
env=env,
)
self.server_port = port
# Record config for in-process reuse
# Record config for in-process reuse (best effort; refined later when ready)
try:
self._server_config = {
"model_name": command[command.index("--model-name") + 1]
@@ -212,12 +287,14 @@ class EmbeddingServerManager:
"embedding_mode": command[command.index("--embedding-mode") + 1]
if "--embedding-mode" in command
else "sentence-transformers",
"provider_options": provider_options or {},
}
except Exception:
self._server_config = {
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
logger.info(f"Server process started with PID: {self.server_process.pid}")
@@ -322,16 +399,27 @@ class EmbeddingServerManager:
# Removed: cross-process adoption no longer supported
return
def _launch_server_process_colab(self, command: list, port: int) -> None:
def _launch_server_process_colab(
self,
command: list,
port: int,
provider_options: Optional[dict] = None,
) -> None:
"""Launch the server process with Colab-specific settings."""
logger.info(f"Colab Command: {' '.join(command)}")
# In Colab, we need to be more careful about process management
env = os.environ.copy()
encoded_options = encode_provider_options(provider_options)
if encoded_options:
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
self.server_process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
env=env,
)
self.server_port = port
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
@@ -345,6 +433,7 @@ class EmbeddingServerManager:
"model_name": "",
"passages_file": "",
"embedding_mode": "sentence-transformers",
"provider_options": provider_options or {},
}
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:

View File

@@ -0,0 +1,240 @@
"""
Metadata filtering engine for LEANN search results.
This module provides generic metadata filtering capabilities that can be applied
to search results from any LEANN backend. The filtering supports various
operators for different data types including numbers, strings, booleans, and lists.
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
# Type alias for filter specifications
FilterValue = Union[str, int, float, bool, list]
FilterSpec = dict[str, FilterValue]
MetadataFilters = dict[str, FilterSpec]
class MetadataFilterEngine:
"""
Engine for evaluating metadata filters against search results.
Supports various operators for filtering based on metadata fields:
- Comparison: ==, !=, <, <=, >, >=
- Membership: in, not_in
- String operations: contains, starts_with, ends_with
- Boolean operations: is_true, is_false
"""
def __init__(self):
"""Initialize the filter engine with supported operators."""
self.operators = {
"==": self._equals,
"!=": self._not_equals,
"<": self._less_than,
"<=": self._less_than_or_equal,
">": self._greater_than,
">=": self._greater_than_or_equal,
"in": self._in,
"not_in": self._not_in,
"contains": self._contains,
"starts_with": self._starts_with,
"ends_with": self._ends_with,
"is_true": self._is_true,
"is_false": self._is_false,
}
def apply_filters(
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
) -> list[dict[str, Any]]:
"""
Apply metadata filters to a list of search results.
Args:
search_results: List of result dictionaries, each containing 'metadata' field
metadata_filters: Dictionary of filter specifications
Format: {"field_name": {"operator": value}}
Returns:
Filtered list of search results
"""
if not metadata_filters:
return search_results
logger.debug(f"Applying filters: {metadata_filters}")
logger.debug(f"Input results count: {len(search_results)}")
filtered_results = []
for result in search_results:
if self._evaluate_filters(result, metadata_filters):
filtered_results.append(result)
logger.debug(f"Filtered results count: {len(filtered_results)}")
return filtered_results
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
"""
Evaluate all filters against a single search result.
All filters must pass (AND logic) for the result to be included.
Args:
result: Full search result dictionary (including metadata, text, etc.)
filters: Filter specifications to evaluate
Returns:
True if all filters pass, False otherwise
"""
for field_name, filter_spec in filters.items():
if not self._evaluate_field_filter(result, field_name, filter_spec):
return False
return True
def _evaluate_field_filter(
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
) -> bool:
"""
Evaluate a single field filter against a search result.
Args:
result: Full search result dictionary
field_name: Name of the field to filter on
filter_spec: Filter specification for this field
Returns:
True if the filter passes, False otherwise
"""
# First check top-level fields, then check metadata
field_value = result.get(field_name)
if field_value is None:
# Try to get from metadata if not found at top level
metadata = result.get("metadata", {})
field_value = metadata.get(field_name)
# Handle missing fields - they fail all filters except existence checks
if field_value is None:
logger.debug(f"Field '{field_name}' not found in result or metadata")
return False
# Evaluate each operator in the filter spec
for operator, expected_value in filter_spec.items():
if operator not in self.operators:
logger.warning(f"Unsupported operator: {operator}")
return False
try:
if not self.operators[operator](field_value, expected_value):
logger.debug(
f"Filter failed: {field_name} {operator} {expected_value} "
f"(actual: {field_value})"
)
return False
except Exception as e:
logger.warning(
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
)
return False
return True
# Comparison operators
def _equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value equals expected value."""
return field_value == expected_value
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value does not equal expected value."""
return field_value != expected_value
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is less than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is greater than or equal to expected value."""
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
# Membership operators
def _in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'in' operator requires a list, tuple, or set")
return field_value in expected_value
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is not in the expected list/collection."""
if not isinstance(expected_value, (list, tuple, set)):
raise ValueError("'not_in' operator requires a list, tuple, or set")
return field_value not in expected_value
# String operators
def _contains(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value contains the expected substring."""
field_str = str(field_value)
expected_str = str(expected_value)
return expected_str in field_str
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value starts with the expected prefix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.startswith(expected_str)
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value ends with the expected suffix."""
field_str = str(field_value)
expected_str = str(expected_value)
return field_str.endswith(expected_str)
# Boolean operators
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is truthy."""
return bool(field_value)
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
"""Check if field value is falsy."""
return not bool(field_value)
# Helper methods
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
"""
Helper for numeric comparisons with type coercion.
Args:
field_value: Value from metadata
expected_value: Value to compare against
compare_func: Comparison function to apply
Returns:
Result of comparison
"""
try:
# Try to convert both values to numbers for comparison
if isinstance(field_value, str) and isinstance(expected_value, str):
# String comparison if both are strings
return compare_func(field_value, expected_value)
# Numeric comparison - attempt to convert to float
field_num = (
float(field_value) if not isinstance(field_value, (int, float)) else field_value
)
expected_num = (
float(expected_value)
if not isinstance(expected_value, (int, float))
else expected_value
)
return compare_func(field_num, expected_num)
except (ValueError, TypeError):
# Fall back to string comparison if numeric conversion fails
return compare_func(str(field_value), str(expected_value))

View File

@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
self.embedding_options = self.meta.get("embedding_options", {})
self.embedding_server_manager = EmbeddingServerManager(
backend_module_name=backend_module_name,
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
passages_file=passages_source_file,
distance_metric=distance_metric,
enable_warmup=kwargs.get("enable_warmup", False),
provider_options=self.embedding_options,
)
if not server_started:
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
from .embedding_compute import compute_embeddings
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
return compute_embeddings([query], self.embedding_model, embedding_mode)
return compute_embeddings(
[query],
self.embedding_model,
embedding_mode,
provider_options=self.embedding_options,
)
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
"""Compute embeddings using the ZMQ embedding server."""

View File

@@ -0,0 +1,74 @@
"""Runtime configuration helpers for LEANN."""
from __future__ import annotations
import json
import os
from typing import Any
# Default fallbacks to preserve current behaviour while keeping them in one place.
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
def _clean_url(value: str) -> str:
"""Normalize URL strings by stripping trailing slashes."""
return value.rstrip("/") if value else value
def resolve_ollama_host(explicit: str | None = None) -> str:
"""Resolve the Ollama-compatible endpoint to use."""
candidates = (
explicit,
os.getenv("LEANN_LOCAL_LLM_HOST"),
os.getenv("LEANN_OLLAMA_HOST"),
os.getenv("OLLAMA_HOST"),
os.getenv("LOCAL_LLM_ENDPOINT"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OLLAMA_HOST)
def resolve_openai_base_url(explicit: str | None = None) -> str:
"""Resolve the base URL for OpenAI-compatible services."""
candidates = (
explicit,
os.getenv("LEANN_OPENAI_BASE_URL"),
os.getenv("OPENAI_BASE_URL"),
os.getenv("LOCAL_OPENAI_BASE_URL"),
)
for candidate in candidates:
if candidate:
return _clean_url(candidate)
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
"""Resolve the API key for OpenAI-compatible services."""
if explicit:
return explicit
return os.getenv("OPENAI_API_KEY")
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
"""Serialize provider options for child processes."""
if not options:
return None
try:
return json.dumps(options)
except (TypeError, ValueError):
# Fall back to empty payload if serialization fails
return None

View File

@@ -2,6 +2,8 @@
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
For agent-facing discovery details, see `llms.txt` in the repository root.
## Prerequisites
Install LEANN globally for MCP integration (with default backend):

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "leann"
version = "0.3.2"
version = "0.3.4"
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
readme = "README.md"
requires-python = ">=3.9"

View File

@@ -99,6 +99,7 @@ wechat-exporter = "wechat_exporter.main:main"
leann-core = { path = "packages/leann-core", editable = true }
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
astchunk = { path = "packages/astchunk-leann", editable = true }
[tool.ruff]
target-version = "py39"

14
tests/test_cli_ask.py Normal file
View File

@@ -0,0 +1,14 @@
from leann.cli import LeannCLI
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
cli = LeannCLI()
parser = cli.create_parser()
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
assert args.command == "ask"
assert args.index_name == "my-docs"
assert args.query == "Where are prompts configured?"

View File

@@ -0,0 +1,365 @@
"""
Comprehensive tests for metadata filtering functionality.
This module tests the MetadataFilterEngine class and its integration
with the LEANN search system.
"""
import os
# Import the modules we're testing
import sys
from unittest.mock import Mock, patch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
from leann.api import PassageManager, SearchResult
from leann.metadata_filter import MetadataFilterEngine
class TestMetadataFilterEngine:
"""Test suite for the MetadataFilterEngine class."""
def setup_method(self):
"""Setup test fixtures."""
self.engine = MetadataFilterEngine()
# Sample search results for testing
self.sample_results = [
{
"id": "doc1",
"score": 0.95,
"text": "This is chapter 1 content",
"metadata": {
"chapter": 1,
"character": "Alice",
"tags": ["adventure", "fantasy"],
"word_count": 150,
"is_published": True,
"genre": "fiction",
},
},
{
"id": "doc2",
"score": 0.87,
"text": "This is chapter 3 content",
"metadata": {
"chapter": 3,
"character": "Bob",
"tags": ["mystery", "thriller"],
"word_count": 250,
"is_published": True,
"genre": "fiction",
},
},
{
"id": "doc3",
"score": 0.82,
"text": "This is chapter 5 content",
"metadata": {
"chapter": 5,
"character": "Alice",
"tags": ["romance", "drama"],
"word_count": 300,
"is_published": False,
"genre": "non-fiction",
},
},
{
"id": "doc4",
"score": 0.78,
"text": "This is chapter 10 content",
"metadata": {
"chapter": 10,
"character": "Charlie",
"tags": ["action", "adventure"],
"word_count": 400,
"is_published": True,
"genre": "fiction",
},
},
]
def test_engine_initialization(self):
"""Test that the filter engine initializes correctly."""
assert self.engine is not None
assert len(self.engine.operators) > 0
assert "==" in self.engine.operators
assert "contains" in self.engine.operators
assert "in" in self.engine.operators
def test_direct_instantiation(self):
"""Test direct instantiation of the engine."""
engine = MetadataFilterEngine()
assert isinstance(engine, MetadataFilterEngine)
def test_no_filters_returns_all_results(self):
"""Test that passing None or empty filters returns all results."""
# Test with None
result = self.engine.apply_filters(self.sample_results, None)
assert len(result) == len(self.sample_results)
# Test with empty dict
result = self.engine.apply_filters(self.sample_results, {})
assert len(result) == len(self.sample_results)
# Test comparison operators
def test_equals_filter(self):
"""Test equals (==) filter."""
filters = {"chapter": {"==": 1}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 1
assert result[0]["id"] == "doc1"
def test_not_equals_filter(self):
"""Test not equals (!=) filter."""
filters = {"genre": {"!=": "fiction"}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 1
assert result[0]["metadata"]["genre"] == "non-fiction"
def test_less_than_filter(self):
"""Test less than (<) filter."""
filters = {"chapter": {"<": 5}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 2
chapters = [r["metadata"]["chapter"] for r in result]
assert all(ch < 5 for ch in chapters)
def test_less_than_or_equal_filter(self):
"""Test less than or equal (<=) filter."""
filters = {"chapter": {"<=": 5}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 3
chapters = [r["metadata"]["chapter"] for r in result]
assert all(ch <= 5 for ch in chapters)
def test_greater_than_filter(self):
"""Test greater than (>) filter."""
filters = {"word_count": {">": 200}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 3 # Documents with word_count 250, 300, 400
word_counts = [r["metadata"]["word_count"] for r in result]
assert all(wc > 200 for wc in word_counts)
def test_greater_than_or_equal_filter(self):
"""Test greater than or equal (>=) filter."""
filters = {"word_count": {">=": 250}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 3
word_counts = [r["metadata"]["word_count"] for r in result]
assert all(wc >= 250 for wc in word_counts)
# Test membership operators
def test_in_filter(self):
"""Test in filter."""
filters = {"character": {"in": ["Alice", "Bob"]}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 3
characters = [r["metadata"]["character"] for r in result]
assert all(ch in ["Alice", "Bob"] for ch in characters)
def test_not_in_filter(self):
"""Test not_in filter."""
filters = {"character": {"not_in": ["Alice", "Bob"]}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 1
assert result[0]["metadata"]["character"] == "Charlie"
# Test string operators
def test_contains_filter(self):
"""Test contains filter."""
filters = {"genre": {"contains": "fiction"}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 4 # Both "fiction" and "non-fiction"
def test_starts_with_filter(self):
"""Test starts_with filter."""
filters = {"genre": {"starts_with": "non"}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 1
assert result[0]["metadata"]["genre"] == "non-fiction"
def test_ends_with_filter(self):
"""Test ends_with filter."""
filters = {"text": {"ends_with": "content"}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 4 # All sample texts end with "content"
# Test boolean operators
def test_is_true_filter(self):
"""Test is_true filter."""
filters = {"is_published": {"is_true": True}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 3
assert all(r["metadata"]["is_published"] for r in result)
def test_is_false_filter(self):
"""Test is_false filter."""
filters = {"is_published": {"is_false": False}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 1
assert not result[0]["metadata"]["is_published"]
# Test compound filters (AND logic)
def test_compound_filters(self):
"""Test multiple filters applied together (AND logic)."""
filters = {"genre": {"==": "fiction"}, "chapter": {"<=": 5}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 2
for r in result:
assert r["metadata"]["genre"] == "fiction"
assert r["metadata"]["chapter"] <= 5
def test_multiple_operators_same_field(self):
"""Test multiple operators on the same field."""
filters = {"word_count": {">=": 200, "<=": 350}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 2
for r in result:
wc = r["metadata"]["word_count"]
assert 200 <= wc <= 350
# Test edge cases
def test_missing_field_fails_filter(self):
"""Test that missing metadata fields fail filters."""
filters = {"nonexistent_field": {"==": "value"}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 0
def test_invalid_operator(self):
"""Test that invalid operators are handled gracefully."""
filters = {"chapter": {"invalid_op": 1}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 0 # Should filter out all results
def test_type_coercion_numeric(self):
"""Test numeric type coercion in comparisons."""
# Add a result with string chapter number
test_results = [
*self.sample_results,
{
"id": "doc5",
"score": 0.75,
"text": "String chapter test",
"metadata": {"chapter": "2", "genre": "test"},
},
]
filters = {"chapter": {"<": 3}}
result = self.engine.apply_filters(test_results, filters)
# Should include doc1 (chapter=1) and doc5 (chapter="2")
assert len(result) == 2
ids = [r["id"] for r in result]
assert "doc1" in ids
assert "doc5" in ids
def test_list_membership_with_nested_tags(self):
"""Test membership operations with list metadata."""
# Note: This tests the metadata structure, not list field filtering
# For list field filtering, we'd need to modify the test data
filters = {"character": {"in": ["Alice"]}}
result = self.engine.apply_filters(self.sample_results, filters)
assert len(result) == 2
assert all(r["metadata"]["character"] == "Alice" for r in result)
def test_empty_results_list(self):
"""Test filtering on empty results list."""
filters = {"chapter": {"==": 1}}
result = self.engine.apply_filters([], filters)
assert len(result) == 0
class TestPassageManagerFiltering:
"""Test suite for PassageManager filtering integration."""
def setup_method(self):
"""Setup test fixtures."""
# Mock the passage manager without actual file I/O
self.passage_manager = Mock(spec=PassageManager)
self.passage_manager.filter_engine = MetadataFilterEngine()
# Sample SearchResult objects
self.search_results = [
SearchResult(
id="doc1",
score=0.95,
text="Chapter 1 content",
metadata={"chapter": 1, "character": "Alice"},
),
SearchResult(
id="doc2",
score=0.87,
text="Chapter 5 content",
metadata={"chapter": 5, "character": "Bob"},
),
SearchResult(
id="doc3",
score=0.82,
text="Chapter 10 content",
metadata={"chapter": 10, "character": "Alice"},
),
]
def test_search_result_filtering(self):
"""Test filtering SearchResult objects."""
# Create a real PassageManager instance just for the filtering method
# We'll mock the file operations
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
filters = {"chapter": {"<=": 5}}
result = pm.filter_search_results(self.search_results, filters)
assert len(result) == 2
chapters = [r.metadata["chapter"] for r in result]
assert all(ch <= 5 for ch in chapters)
def test_filter_search_results_no_filters(self):
"""Test that None filters return all results."""
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
result = pm.filter_search_results(self.search_results, None)
assert len(result) == len(self.search_results)
def test_filter_maintains_search_result_type(self):
"""Test that filtering returns SearchResult objects."""
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
filters = {"character": {"==": "Alice"}}
result = pm.filter_search_results(self.search_results, filters)
assert len(result) == 2
for r in result:
assert isinstance(r, SearchResult)
assert r.metadata["character"] == "Alice"
# Integration tests would go here, but they require actual LEANN backend setup
# These would test the full pipeline from LeannSearcher.search() with metadata_filters
if __name__ == "__main__":
# Run basic smoke tests
engine = MetadataFilterEngine()
sample_data = [
{
"id": "test1",
"score": 0.9,
"text": "Test content",
"metadata": {"chapter": 1, "published": True},
}
]
# Test basic filtering
result = engine.apply_filters(sample_data, {"chapter": {"==": 1}})
assert len(result) == 1
print("✅ Basic filtering test passed")
result = engine.apply_filters(sample_data, {"chapter": {"==": 2}})
assert len(result) == 0
print("✅ No match filtering test passed")
print("🎉 All smoke tests passed!")

7530
uv.lock generated
View File

File diff suppressed because it is too large Load Diff