Compare commits
25 Commits
arch-eval
...
dynamic-ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d02aee6901 | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
62a5d7b31d | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
0a69118f87 | ||
|
|
880a039e1d | ||
|
|
4a39b40e72 | ||
|
|
ed5fd88a85 | ||
|
|
8f4f2b4873 | ||
|
|
6a06bd893a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 |
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
58
.github/workflows/build-reusable.yml
vendored
58
.github/workflows/build-reusable.yml
vendored
@@ -54,6 +54,17 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: ubuntu-22.04
|
- os: ubuntu-22.04
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
@@ -108,13 +119,46 @@ jobs:
|
|||||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
patchelf
|
patchelf
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
# Debug: Show system information
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
echo "🔍 System Information:"
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
echo "Architecture: $(uname -m)"
|
||||||
source /opt/intel/oneapi/setvars.sh
|
echo "OS: $(uname -a)"
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
|
||||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
- name: Install system dependencies (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
|
|||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -22,6 +22,7 @@ demo/experiment_results/**/*.json
|
|||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -93,5 +94,11 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
batchtest.py
|
batchtest.py
|
||||||
tests/__pytest_cache__/
|
tests/__pytest_cache__/
|
||||||
tests/__pycache__/
|
tests/__pycache__/
|
||||||
|
paru-bin/
|
||||||
|
|
||||||
|
CLAUDE.md
|
||||||
|
CLAUDE.local.md
|
||||||
|
.claude/*.local.*
|
||||||
|
.claude/local/*
|
||||||
benchmarks/data/
|
benchmarks/data/
|
||||||
|
test_add/*
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||||
|
|||||||
@@ -13,4 +13,5 @@ repos:
|
|||||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
60
README.md
60
README.md
@@ -8,6 +8,8 @@
|
|||||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||||
|
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||||
|
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
@@ -176,8 +178,7 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
**AST-Aware Code Chunking** - LEANN also features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files, providing improved code understanding compared to traditional text-based approaches.
|
|
||||||
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md) to learn more.
|
|
||||||
|
|
||||||
### Generation Model Setup
|
### Generation Model Setup
|
||||||
|
|
||||||
@@ -221,7 +222,8 @@ ollama pull llama3.2:1b
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### ⭐ Flexible Configuration
|
|
||||||
|
## ⭐ Flexible Configuration
|
||||||
|
|
||||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
@@ -477,6 +479,15 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||||
|
|
||||||
|
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||||
|
|
||||||
|
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||||
|
|
||||||
**Key features:**
|
**Key features:**
|
||||||
@@ -618,6 +629,46 @@ Options:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## 🚀 Advanced Features
|
||||||
|
|
||||||
|
### 🎯 Metadata Filtering
|
||||||
|
|
||||||
|
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Add metadata during indexing
|
||||||
|
builder.add_text(
|
||||||
|
"def authenticate_user(token): ...",
|
||||||
|
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||||
|
|
||||||
|
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||||
|
|
||||||
|
### 🔍 Grep Search
|
||||||
|
|
||||||
|
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||||
|
|
||||||
|
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||||
|
|
||||||
## 🏗️ Architecture & How It Works
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -697,6 +748,9 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
|
|
||||||
|
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||||
|
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
|
|
||||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
|
|||||||
@@ -299,7 +299,6 @@ class BaseRAGExample(ABC):
|
|||||||
chat = LeannChat(
|
chat = LeannChat(
|
||||||
index_path,
|
index_path,
|
||||||
llm_config=self.get_llm_config(args),
|
llm_config=self.get_llm_config(args),
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
complexity=args.search_complexity,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,38 @@
|
|||||||
"""
|
"""Unified chunking utilities facade.
|
||||||
Chunking utilities for LEANN RAG applications.
|
|
||||||
Provides AST-aware and traditional text chunking functionality.
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .utils import (
|
import sys
|
||||||
CODE_EXTENSIONS,
|
from pathlib import Path
|
||||||
create_ast_chunks,
|
|
||||||
create_text_chunks,
|
try:
|
||||||
create_traditional_chunks,
|
from leann.chunking_utils import (
|
||||||
detect_code_files,
|
CODE_EXTENSIONS,
|
||||||
get_language_from_extension,
|
create_ast_chunks,
|
||||||
)
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CODE_EXTENSIONS",
|
"CODE_EXTENSIONS",
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
|
|||||||
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
44
benchmarks/data/README.md
Executable file
44
benchmarks/data/README.md
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
@@ -12,7 +12,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
@@ -197,6 +197,25 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
@@ -318,9 +337,24 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
new_results = searcher.search(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ except ImportError:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: list[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
@@ -34,7 +34,7 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
@@ -179,11 +179,14 @@ class Benchmark:
|
|||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
# mps sync
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
torch.mps.synchronize()
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|||||||
@@ -26,6 +26,21 @@ leann build my-code-index --docs ./src --use-ast-chunking
|
|||||||
uv pip install -e "."
|
uv pip install -e "."
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
### When to Use AST Chunking
|
### When to Use AST Chunking
|
||||||
|
|||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
# LEANN Metadata Filtering Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Adding Metadata to Your Documents
|
||||||
|
|
||||||
|
When building your index, add metadata to each text chunk:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
|
||||||
|
# Add text with metadata
|
||||||
|
builder.add_text(
|
||||||
|
text="Chapter 1: Alice falls down the rabbit hole",
|
||||||
|
metadata={
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"themes": ["adventure", "curiosity"],
|
||||||
|
"word_count": 150
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.build_index("alice_in_wonderland_index")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Searching with Metadata Filters
|
||||||
|
|
||||||
|
Use the `metadata_filters` parameter in search calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="What happens to Alice?",
|
||||||
|
top_k=10,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||||
|
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Filter Syntax
|
||||||
|
|
||||||
|
### Basic Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"field_name": {"operator": value},
|
||||||
|
"another_field": {"operator": value}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
#### Comparison Operators
|
||||||
|
- `"=="`: Equal to
|
||||||
|
- `"!="`: Not equal to
|
||||||
|
- `"<"`: Less than
|
||||||
|
- `"<="`: Less than or equal
|
||||||
|
- `">"`: Greater than
|
||||||
|
- `">="`: Greater than or equal
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||||
|
{"page": {">": 100}} # Pages after 100
|
||||||
|
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||||
|
{"word_count": {"<": 500}} # Short passages
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Membership Operators
|
||||||
|
- `"in"`: Value is in list
|
||||||
|
- `"not_in"`: Value is not in list
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||||
|
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||||
|
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||||
|
```
|
||||||
|
|
||||||
|
#### String Operators
|
||||||
|
- `"contains"`: String contains substring
|
||||||
|
- `"starts_with"`: String starts with prefix
|
||||||
|
- `"ends_with"`: String ends with suffix
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||||
|
{"filename": {"ends_with": ".py"}} # Python files
|
||||||
|
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Boolean Operators
|
||||||
|
- `"is_true"`: Field is truthy
|
||||||
|
- `"is_false"`: Field is falsy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"is_published": {"is_true": True}} # Published content
|
||||||
|
{"is_draft": {"is_false": False}} # Not drafts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Operators on Same Field
|
||||||
|
|
||||||
|
You can apply multiple operators to the same field (AND logic):
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"word_count": {
|
||||||
|
">=": 100, # At least 100 words
|
||||||
|
"<=": 500 # At most 500 words
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compound Filters
|
||||||
|
|
||||||
|
Multiple fields are combined with AND logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"chapter": {"<=": 10}, # Up to chapter 10
|
||||||
|
"character": {"==": "Alice"}, # About Alice
|
||||||
|
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Case Examples
|
||||||
|
|
||||||
|
### 1. Spoiler-Free Book Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reader has only read up to chapter 5
|
||||||
|
def search_spoiler_free(query, max_chapter):
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": max_chapter},
|
||||||
|
"spoiler_level": {"in": ["none", "low"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Document Management by Date
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find recent documents
|
||||||
|
recent_docs = searcher.search(
|
||||||
|
query="project updates",
|
||||||
|
metadata_filters={
|
||||||
|
"date": {">=": "2024-01-01"},
|
||||||
|
"document_type": {"==": "report"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Code Search by File Type
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search only Python files
|
||||||
|
python_code = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Content Filtering by Audience
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Age-appropriate content
|
||||||
|
family_content = searcher.search(
|
||||||
|
query="adventure stories",
|
||||||
|
metadata_filters={
|
||||||
|
"age_rating": {"in": ["G", "PG"]},
|
||||||
|
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Multi-Book Series Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search across first 3 books only
|
||||||
|
early_series = searcher.search(
|
||||||
|
query="character development",
|
||||||
|
metadata_filters={
|
||||||
|
"series": {"==": "Harry Potter"},
|
||||||
|
"book_number": {"<=": 3}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Don't forget to set up the environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Run the spoiler-free book RAG example
|
||||||
|
uv run examples/spoiler_free_book_rag.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates:
|
||||||
|
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||||
|
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||||
|
- Different scenarios for readers at various points in the book
|
||||||
|
|
||||||
|
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||||
|
|
||||||
|
## Advanced Patterns
|
||||||
|
|
||||||
|
### Custom Chunking with metadata
|
||||||
|
|
||||||
|
```python
|
||||||
|
def chunk_book_with_metadata(book_text, book_info):
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||||
|
# Extract entities, themes, etc.
|
||||||
|
characters = extract_characters(chapter_text)
|
||||||
|
themes = classify_themes(chapter_text)
|
||||||
|
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||||
|
|
||||||
|
# Create chunks with rich metadata
|
||||||
|
for paragraph in split_paragraphs(chapter_text):
|
||||||
|
chunks.append({
|
||||||
|
"text": paragraph,
|
||||||
|
"metadata": {
|
||||||
|
"book_title": book_info["title"],
|
||||||
|
"chapter": chapter_num,
|
||||||
|
"characters": characters,
|
||||||
|
"themes": themes,
|
||||||
|
"spoiler_level": spoiler_level,
|
||||||
|
"word_count": len(paragraph.split()),
|
||||||
|
"reading_level": calculate_reading_level(paragraph)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Efficient Filtering Strategies
|
||||||
|
|
||||||
|
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||||
|
|
||||||
|
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||||
|
|
||||||
|
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||||
|
|
||||||
|
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||||
|
|
||||||
|
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||||
|
|
||||||
|
### Adding Metadata to Existing Indices
|
||||||
|
|
||||||
|
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Read existing passages and add metadata
|
||||||
|
def add_metadata_to_existing_chunks(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
# Extract or assign metadata based on content
|
||||||
|
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Rebuild index with metadata
|
||||||
|
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
for chunk in enhanced_chunks:
|
||||||
|
builder.add_text(chunk["text"], chunk["metadata"])
|
||||||
|
builder.build_index("enhanced_index")
|
||||||
|
```
|
||||||
380
examples/dynamic_add_leann_no_recompute.py
Normal file
380
examples/dynamic_add_leann_no_recompute.py
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
"""
|
||||||
|
Dynamic add example for LEANN using HNSW backend without recompute.
|
||||||
|
|
||||||
|
- Builds a base index from a directory of documents
|
||||||
|
- Incrementally adds new documents without recomputing stored embeddings
|
||||||
|
|
||||||
|
Defaults:
|
||||||
|
- Base data: /Users/yichuan/Desktop/code/LEANN/leann/data
|
||||||
|
- Incremental data: /Users/yichuan/Desktop/code/LEANN/leann/test_add
|
||||||
|
- Index path: <index_dir>/documents.leann
|
||||||
|
|
||||||
|
Usage examples:
|
||||||
|
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
|
||||||
|
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
|
||||||
|
--index-dir ./test_doc_files
|
||||||
|
|
||||||
|
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
|
||||||
|
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
|
||||||
|
--index-dir ./test_doc_files
|
||||||
|
|
||||||
|
Quick recompute test (both true):
|
||||||
|
# Recompute build
|
||||||
|
uv run python examples/dynamic_add_leann_no_recompute.py --build-base \
|
||||||
|
--recompute-build --ef-construction 200 \
|
||||||
|
--base-dir /Users/yichuan/Desktop/code/LEANN/leann/data \
|
||||||
|
--index-dir ./test_doc_files --index-name documents.leann
|
||||||
|
|
||||||
|
# Recompute add
|
||||||
|
uv run python examples/dynamic_add_leann_no_recompute.py --add-incremental \
|
||||||
|
--recompute-add --ef-construction 32 \
|
||||||
|
--add-dir /Users/yichuan/Desktop/code/LEANN/leann/test_add \
|
||||||
|
--index-dir ./test_doc_files --index-name documents.leann
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Ensure we can import from the local packages and apps folders
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
CORE_SRC = ROOT / "packages" / "leann-core" / "src"
|
||||||
|
HNSW_PKG_DIR = ROOT / "packages" / "leann-backend-hnsw"
|
||||||
|
APPS_DIR = ROOT / "apps"
|
||||||
|
|
||||||
|
|
||||||
|
# Prefer the installed backend if available (it contains the compiled extension)
|
||||||
|
def _prefer_installed(pkg_name: str) -> bool:
|
||||||
|
try:
|
||||||
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
spec = importlib.util.find_spec(pkg_name)
|
||||||
|
if spec and spec.origin and "site-packages" in spec.origin:
|
||||||
|
# ensure the faiss shim/extension is importable from the installed package
|
||||||
|
importlib.import_module(f"{pkg_name}.faiss")
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Prepend paths, but only add the repo backend if the installed one is not present
|
||||||
|
paths_to_prepend = [CORE_SRC, APPS_DIR]
|
||||||
|
if not _prefer_installed("leann_backend_hnsw"):
|
||||||
|
paths_to_prepend.insert(1, HNSW_PKG_DIR)
|
||||||
|
|
||||||
|
for p in paths_to_prepend:
|
||||||
|
p_str = str(p)
|
||||||
|
if p_str not in sys.path:
|
||||||
|
sys.path.insert(0, p_str)
|
||||||
|
|
||||||
|
# Defer non-stdlib imports until after sys.path setup within functions (avoid E402)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_documents(data_dir: str, required_exts: Optional[list[str]] = None) -> list[Any]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader # type: ignore
|
||||||
|
|
||||||
|
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
|
||||||
|
if required_exts:
|
||||||
|
reader_kwargs["required_exts"] = required_exts
|
||||||
|
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
|
||||||
|
return documents
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_index_dir(index_dir: Path) -> None:
|
||||||
|
index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _index_files(index_path: Path) -> tuple[Path, Path, Path]:
|
||||||
|
"""Return (passages.jsonl, passages.idx, index.index) paths for a given index base path.
|
||||||
|
|
||||||
|
Note: HNSWBackend writes the FAISS index using the stem (without .leann),
|
||||||
|
i.e., for base 'documents.leann' the file is 'documents.index'. We prefer the
|
||||||
|
existing file among candidates.
|
||||||
|
"""
|
||||||
|
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||||
|
offsets_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||||
|
candidate_name_index = index_path.parent / f"{index_path.name}.index"
|
||||||
|
candidate_stem_index = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
index_file = candidate_stem_index if candidate_stem_index.exists() else candidate_name_index
|
||||||
|
return passages_file, offsets_file, index_file
|
||||||
|
|
||||||
|
|
||||||
|
def _read_meta(index_path: Path) -> dict[str, Any]:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
raise FileNotFoundError(f"Metadata file not found: {meta_path}")
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _autodetect_index_base(index_dir: Path) -> Optional[Path]:
|
||||||
|
"""If exactly one *.leann.meta.json exists, return its base path (without .meta.json)."""
|
||||||
|
candidates = list(index_dir.glob("*.leann.meta.json"))
|
||||||
|
if len(candidates) == 1:
|
||||||
|
meta = candidates[0]
|
||||||
|
base = meta.with_suffix("") # remove .json
|
||||||
|
base = base.with_suffix("") # remove .meta
|
||||||
|
return base
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_offset_map(offsets_file: Path) -> dict[str, int]:
|
||||||
|
if not offsets_file.exists():
|
||||||
|
return {}
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _next_numeric_id(existing_ids: list[str]) -> int:
|
||||||
|
numeric_ids = [int(x) for x in existing_ids if x.isdigit()]
|
||||||
|
if not numeric_ids:
|
||||||
|
return 0
|
||||||
|
return max(numeric_ids) + 1
|
||||||
|
|
||||||
|
|
||||||
|
def build_base_index(
|
||||||
|
base_dir: str,
|
||||||
|
index_dir: str,
|
||||||
|
index_name: str,
|
||||||
|
embedding_model: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
chunk_size: int,
|
||||||
|
chunk_overlap: int,
|
||||||
|
file_types: Optional[list[str]] = None,
|
||||||
|
max_items: int = -1,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
recompute_build: bool = False,
|
||||||
|
) -> str:
|
||||||
|
print(f"Building base index from: {base_dir}")
|
||||||
|
documents = _load_documents(base_dir, required_exts=file_types)
|
||||||
|
if not documents:
|
||||||
|
raise ValueError(f"No documents found in base_dir: {base_dir}")
|
||||||
|
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
|
texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
if max_items > 0 and len(texts) > max_items:
|
||||||
|
texts = texts[:max_items]
|
||||||
|
print(f"Limiting to {max_items} chunks")
|
||||||
|
|
||||||
|
index_dir_path = Path(index_dir)
|
||||||
|
_ensure_index_dir(index_dir_path)
|
||||||
|
index_path = index_dir_path / index_name
|
||||||
|
|
||||||
|
print("Creating HNSW index (non-compact)...")
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_recompute=recompute_build,
|
||||||
|
is_compact=False,
|
||||||
|
efConstruction=(ef_construction if ef_construction is not None else 200),
|
||||||
|
)
|
||||||
|
for t in texts:
|
||||||
|
builder.add_text(t)
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
# Register for discovery
|
||||||
|
register_project_directory(Path.cwd())
|
||||||
|
|
||||||
|
print(f"Base index built at: {index_path}")
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def add_incremental(
|
||||||
|
add_dir: str,
|
||||||
|
index_dir: str,
|
||||||
|
index_name: Optional[str] = None,
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embedding_mode: Optional[str] = None,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
file_types: Optional[list[str]] = None,
|
||||||
|
max_items: int = -1,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
recompute_add: bool = False,
|
||||||
|
) -> str:
|
||||||
|
print(f"Adding incremental data from: {add_dir}")
|
||||||
|
index_dir_path = Path(index_dir)
|
||||||
|
index_path = index_dir_path / (index_name or "documents.leann")
|
||||||
|
|
||||||
|
# If specified base doesn't exist, try to auto-detect an existing base
|
||||||
|
try:
|
||||||
|
_read_meta(index_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
auto_base = _autodetect_index_base(index_dir_path)
|
||||||
|
if auto_base is not None:
|
||||||
|
print(f"Auto-detected index base: {auto_base.name}")
|
||||||
|
index_path = auto_base
|
||||||
|
_read_meta(index_path)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No index metadata found for base '{index_path.name}'. Build base first with --build-base "
|
||||||
|
f"or provide --index-name to match an existing index (e.g., 'test_doc_files.leann')."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare validated context from core (checks backend/no-recompute and resolves embedding defaults)
|
||||||
|
from leann.api import create_incremental_add_context, incremental_add_texts_with_context
|
||||||
|
|
||||||
|
ctx = create_incremental_add_context(
|
||||||
|
str(index_path),
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
data_dir=add_dir,
|
||||||
|
required_exts=file_types,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
max_items=max_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use prepared texts from context to perform the add
|
||||||
|
prepared_texts = ctx.prepared_texts or []
|
||||||
|
if not prepared_texts:
|
||||||
|
print("No new chunks to add.")
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
added = incremental_add_texts_with_context(
|
||||||
|
ctx,
|
||||||
|
prepared_texts,
|
||||||
|
ef_construction=ef_construction,
|
||||||
|
recompute=recompute_add,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Incremental add completed. Added {added} chunks. Index: {index_path}")
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Dynamic add to LEANN HNSW index without recompute",
|
||||||
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--build-base", action="store_true", help="Build base index")
|
||||||
|
parser.add_argument("--add-incremental", action="store_true", help="Add incremental data")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-dir",
|
||||||
|
type=str,
|
||||||
|
default="/Users/yichuan/Desktop/code/LEANN/leann/data",
|
||||||
|
help="Base data directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--add-dir",
|
||||||
|
type=str,
|
||||||
|
default="/Users/yichuan/Desktop/code/LEANN/leann/test_add",
|
||||||
|
help="Incremental data directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default="./test_doc_files",
|
||||||
|
help="Directory containing the index",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-name",
|
||||||
|
type=str,
|
||||||
|
default="documents.leann",
|
||||||
|
help=(
|
||||||
|
"Index base file name. If you built via document_rag.py, use 'test_doc_files.leann'. "
|
||||||
|
"Default: documents.leann"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="facebook/contriever",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--chunk-size", type=int, default=256)
|
||||||
|
parser.add_argument("--chunk-overlap", type=int, default=128)
|
||||||
|
parser.add_argument("--file-types", nargs="+", default=None)
|
||||||
|
parser.add_argument("--max-items", type=int, default=-1)
|
||||||
|
parser.add_argument("--ef-construction", type=int, default=32)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recompute-add", action="store_true", help="Enable recompute-mode add (non-compact only)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recompute-build",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable recompute-mode base build (non-compact only)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.build_base and not args.add_incremental:
|
||||||
|
print("Nothing to do. Use --build-base and/or --add-incremental.")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_path_str: Optional[str] = None
|
||||||
|
|
||||||
|
if args.build_base:
|
||||||
|
index_path_str = build_base_index(
|
||||||
|
base_dir=args.base_dir,
|
||||||
|
index_dir=args.index_dir,
|
||||||
|
index_name=args.index_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
file_types=args.file_types,
|
||||||
|
max_items=args.max_items,
|
||||||
|
ef_construction=args.ef_construction,
|
||||||
|
recompute_build=args.recompute_build,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.add_incremental:
|
||||||
|
index_path_str = add_incremental(
|
||||||
|
add_dir=args.add_dir,
|
||||||
|
index_dir=args.index_dir,
|
||||||
|
index_name=args.index_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
file_types=args.file_types,
|
||||||
|
max_items=args.max_items,
|
||||||
|
ef_construction=args.ef_construction,
|
||||||
|
recompute_add=args.recompute_add,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional: quick test query using searcher
|
||||||
|
if index_path_str:
|
||||||
|
try:
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path_str)
|
||||||
|
query = "what is LEANN?"
|
||||||
|
if args.add_incremental:
|
||||||
|
query = "what is the multi vector search and how it works?"
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
if results:
|
||||||
|
print(f"Sample result: {results[0].text[:80]}...")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||||
|
|
||||||
|
This example demonstrates how to use LEANN's metadata filtering to create
|
||||||
|
a spoiler-free book RAG system where users can search for information
|
||||||
|
up to a specific chapter they've read.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python spoiler_free_book_rag.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Add LEANN to path (adjust path as needed)
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create sample book chunks with metadata for demonstration.
|
||||||
|
|
||||||
|
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||||
|
and extract chapter boundaries, character mentions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_title: Title of the book
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with text and metadata
|
||||||
|
"""
|
||||||
|
# Sample book chunks with metadata
|
||||||
|
# In practice, you'd use proper text processing libraries
|
||||||
|
|
||||||
|
sample_chunks = [
|
||||||
|
{
|
||||||
|
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 1,
|
||||||
|
"characters": ["Alice", "Sister"],
|
||||||
|
"themes": ["boredom", "curiosity"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 2,
|
||||||
|
"characters": ["Alice", "White Rabbit"],
|
||||||
|
"themes": ["decision", "surprise", "magic"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 2,
|
||||||
|
"page": 15,
|
||||||
|
"characters": ["Alice"],
|
||||||
|
"themes": ["falling", "wonder", "transformation"],
|
||||||
|
"location": "rabbit hole",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 6,
|
||||||
|
"page": 85,
|
||||||
|
"characters": ["Alice", "Cheshire Cat"],
|
||||||
|
"themes": ["madness", "philosophy", "identity"],
|
||||||
|
"location": "Duchess's house",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 8,
|
||||||
|
"page": 120,
|
||||||
|
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||||
|
"themes": ["justice", "absurdity", "authority"],
|
||||||
|
"location": "Queen's court",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 12,
|
||||||
|
"page": 180,
|
||||||
|
"characters": ["Alice", "Sister", "Rabbit"],
|
||||||
|
"themes": ["revelation", "reality", "growth"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return sample_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a LEANN index with book chunks that include spoiler metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_chunks: List of book chunks with metadata
|
||||||
|
index_name: Name for the index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the built index
|
||||||
|
"""
|
||||||
|
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||||
|
|
||||||
|
# Initialize LEANN builder
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add each chunk with its metadata
|
||||||
|
for chunk in book_chunks:
|
||||||
|
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
index_path = f"{index_name}_book_index"
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
print(f"✅ Index built successfully: {index_path}")
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def spoiler_free_search(
|
||||||
|
index_path: str,
|
||||||
|
query: str,
|
||||||
|
max_chapter: int,
|
||||||
|
character_filter: Optional[list[str]] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Perform a spoiler-free search on the book index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: Search query
|
||||||
|
max_chapter: Maximum chapter number to include
|
||||||
|
character_filter: Optional list of characters to focus on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results safe for the reader
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||||
|
|
||||||
|
if character_filter:
|
||||||
|
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||||
|
|
||||||
|
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def demo_spoiler_free_rag():
|
||||||
|
"""
|
||||||
|
Demonstrate the spoiler-free book RAG system.
|
||||||
|
"""
|
||||||
|
print("🎭 Spoiler-Free Book RAG Demo")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
# Step 1: Prepare book data
|
||||||
|
book_title = "Alice's Adventures in Wonderland"
|
||||||
|
book_chunks = chunk_book_with_metadata(book_title)
|
||||||
|
|
||||||
|
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||||
|
|
||||||
|
# Step 2: Build the index (in practice, this would be done once)
|
||||||
|
try:
|
||||||
|
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||||
|
print(
|
||||||
|
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Demonstrate various spoiler-free searches
|
||||||
|
search_scenarios = [
|
||||||
|
{
|
||||||
|
"description": "Reader who has only read Chapter 1",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read up to Chapter 5",
|
||||||
|
"query": "Tell me about Alice's adventures",
|
||||||
|
"max_chapter": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read most of the book",
|
||||||
|
"query": "What does the Cheshire Cat represent?",
|
||||||
|
"max_chapter": 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read the whole book",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 12,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in search_scenarios:
|
||||||
|
print(f"\n📚 Scenario: {scenario['description']}")
|
||||||
|
print(f" Query: {scenario['query']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = spoiler_free_search(
|
||||||
|
index_path=index_path,
|
||||||
|
query=scenario["query"],
|
||||||
|
max_chapter=scenario["max_chapter"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📄 Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||||
|
chapter = result.metadata.get("chapter", "?")
|
||||||
|
location = result.metadata.get("location", "?")
|
||||||
|
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Search failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||||
|
print("=====================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_spoiler_free_rag()
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error running demo: {e}")
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
@@ -4,8 +4,8 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
dependencies = ["leann-core==0.3.2", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: c593831474...19f9603c72
@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
# Disable additional SIMD versions to speed up compilation
|
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# ARM64-specific configuration
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||||
|
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||||
|
else()
|
||||||
|
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||||
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||||
|
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Additional optimization options from INSTALL.md
|
# Additional optimization options from INSTALL.md
|
||||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ from leann.registry import register_backend
|
|||||||
from leann.searcher_base import BaseSearcher
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr
|
||||||
|
from .prune_index import prune_embeddings_preserve_graph_inplace
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -89,8 +91,16 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
index_file = index_dir / f"{index_prefix}.index"
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
faiss.write_index(index, str(index_file))
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
if self.is_compact:
|
if self.is_recompute:
|
||||||
self._convert_to_csr(index_file)
|
if self.is_compact:
|
||||||
|
self._convert_to_csr(index_file)
|
||||||
|
else:
|
||||||
|
# Non-compact format: prune only embeddings, keep original graph
|
||||||
|
ok = prune_embeddings_preserve_graph_inplace(str(index_file))
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Pruning embeddings while preserving graph failed for non-compact index"
|
||||||
|
)
|
||||||
|
|
||||||
def _convert_to_csr(self, index_file: Path):
|
def _convert_to_csr(self, index_file: Path):
|
||||||
"""Convert built index to CSR format"""
|
"""Convert built index to CSR format"""
|
||||||
@@ -147,7 +157,13 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
self.is_pruned
|
self.is_pruned
|
||||||
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
) # In C++ code, it's called is_recompute, but it's only for loading IIUC.
|
||||||
|
|
||||||
self._index = faiss.read_index(str(index_file), faiss.IO_FLAG_MMAP, hnsw_config)
|
# If pruned (recompute mode), explicitly skip storage to avoid reading
|
||||||
|
# the pruned section. Still allow MMAP for graph.
|
||||||
|
io_flags = faiss.IO_FLAG_MMAP
|
||||||
|
if self.is_pruned:
|
||||||
|
io_flags |= faiss.IO_FLAG_SKIP_STORAGE
|
||||||
|
|
||||||
|
self._index = faiss.read_index(str(index_file), io_flags, hnsw_config)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -236,6 +252,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||||
|
|
||||||
|
search_time = time.time()
|
||||||
self._index.search(
|
self._index.search(
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
faiss.swig_ptr(query),
|
faiss.swig_ptr(query),
|
||||||
@@ -244,7 +261,60 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
faiss.swig_ptr(labels),
|
faiss.swig_ptr(labels),
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Helper API for incremental add (Python-level) ----------
|
||||||
|
def add_vectors(
|
||||||
|
index_file_path: str,
|
||||||
|
embeddings: np.ndarray,
|
||||||
|
*,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Append vectors to an existing non-compact HNSW index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_file_path: Path to the HNSW .index file
|
||||||
|
embeddings: float32 numpy array (N, D)
|
||||||
|
ef_construction: Optional override for efConstruction during insertion
|
||||||
|
recompute: Reserved for future use to control insertion-time recompute behaviors
|
||||||
|
"""
|
||||||
|
from . import faiss # type: ignore
|
||||||
|
|
||||||
|
if embeddings.dtype != np.float32:
|
||||||
|
embeddings = embeddings.astype(np.float32)
|
||||||
|
if not embeddings.flags.c_contiguous:
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Load index normally to ensure storage is present; toggle is_recompute on the object
|
||||||
|
index = faiss.read_index(str(index_file_path), faiss.IO_FLAG_MMAP)
|
||||||
|
|
||||||
|
# Best-effort: explicitly set flag on the object if the binding exposes it
|
||||||
|
try:
|
||||||
|
index.is_recompute = bool(recompute)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if ef_construction is not None:
|
||||||
|
index.hnsw.efConstruction = int(ef_construction)
|
||||||
|
except Exception:
|
||||||
|
# Best-effort; ignore if backend doesn't expose setter
|
||||||
|
pass
|
||||||
|
|
||||||
|
# For non-compact HNSW, calling add directly is sufficient. When is_recompute is set
|
||||||
|
# (via config or attribute), FAISS will run the insertion/search path accordingly.
|
||||||
|
# To strictly follow per-point insert semantics in recompute mode, add one-by-one.
|
||||||
|
if recompute:
|
||||||
|
# Insert row by row
|
||||||
|
n = embeddings.shape[0]
|
||||||
|
for i in range(n):
|
||||||
|
row = embeddings[i : i + 1]
|
||||||
|
index.add(1, faiss.swig_ptr(row))
|
||||||
|
else:
|
||||||
|
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||||
|
faiss.write_index(index, str(index_file_path))
|
||||||
|
|||||||
149
packages/leann-backend-hnsw/leann_backend_hnsw/prune_index.py
Normal file
149
packages/leann-backend-hnsw/leann_backend_hnsw/prune_index.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import os
|
||||||
|
import struct
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .convert_to_csr import (
|
||||||
|
EXPECTED_HNSW_FOURCCS,
|
||||||
|
NULL_INDEX_FOURCC,
|
||||||
|
read_struct,
|
||||||
|
read_vector_raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_vector_raw(f_out, count: int, data_bytes: bytes) -> None:
|
||||||
|
"""Write a vector in the same binary layout as read_vector_raw reads: <Q count> + raw bytes."""
|
||||||
|
f_out.write(struct.pack("<Q", count))
|
||||||
|
if count > 0 and data_bytes:
|
||||||
|
f_out.write(data_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
def prune_embeddings_preserve_graph(input_filename: str, output_filename: str) -> bool:
|
||||||
|
"""
|
||||||
|
Copy an original (non-compact) HNSW index file while pruning the trailing embedding storage.
|
||||||
|
Preserves the graph structure and metadata exactly; only writes a NULL storage marker instead of
|
||||||
|
the original storage fourcc and payload.
|
||||||
|
|
||||||
|
Returns True on success.
|
||||||
|
"""
|
||||||
|
print(f"Pruning embeddings from {input_filename} to {output_filename}")
|
||||||
|
print("--------------------------------")
|
||||||
|
# running in mode is-recompute=True and is-compact=False
|
||||||
|
in_path = Path(input_filename)
|
||||||
|
out_path = Path(output_filename)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(in_path, "rb") as f_in, open(out_path, "wb") as f_out:
|
||||||
|
# Header
|
||||||
|
index_fourcc = read_struct(f_in, "<I")
|
||||||
|
if index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
|
# Still proceed, but this is unexpected
|
||||||
|
pass
|
||||||
|
f_out.write(struct.pack("<I", index_fourcc))
|
||||||
|
|
||||||
|
d = read_struct(f_in, "<i")
|
||||||
|
ntotal_hdr = read_struct(f_in, "<q")
|
||||||
|
dummy1 = read_struct(f_in, "<q")
|
||||||
|
dummy2 = read_struct(f_in, "<q")
|
||||||
|
is_trained = read_struct(f_in, "?")
|
||||||
|
metric_type = read_struct(f_in, "<i")
|
||||||
|
f_out.write(struct.pack("<i", d))
|
||||||
|
f_out.write(struct.pack("<q", ntotal_hdr))
|
||||||
|
f_out.write(struct.pack("<q", dummy1))
|
||||||
|
f_out.write(struct.pack("<q", dummy2))
|
||||||
|
f_out.write(struct.pack("<?", is_trained))
|
||||||
|
f_out.write(struct.pack("<i", metric_type))
|
||||||
|
|
||||||
|
if metric_type > 1:
|
||||||
|
metric_arg = read_struct(f_in, "<f")
|
||||||
|
f_out.write(struct.pack("<f", metric_arg))
|
||||||
|
|
||||||
|
# Vectors: assign_probas (double), cum_nneighbor_per_level (int32), levels (int32)
|
||||||
|
cnt, data = read_vector_raw(f_in, "d")
|
||||||
|
_write_vector_raw(f_out, cnt, data)
|
||||||
|
|
||||||
|
cnt, data = read_vector_raw(f_in, "i")
|
||||||
|
_write_vector_raw(f_out, cnt, data)
|
||||||
|
|
||||||
|
cnt, data = read_vector_raw(f_in, "i")
|
||||||
|
_write_vector_raw(f_out, cnt, data)
|
||||||
|
|
||||||
|
# Probe potential extra alignment/flag byte present in some original formats
|
||||||
|
probe = f_in.read(1)
|
||||||
|
if probe:
|
||||||
|
if probe == b"\x00":
|
||||||
|
# Preserve this unexpected 0x00 byte
|
||||||
|
f_out.write(probe)
|
||||||
|
else:
|
||||||
|
# Likely part of the next vector; rewind
|
||||||
|
f_in.seek(-1, os.SEEK_CUR)
|
||||||
|
|
||||||
|
# Offsets (uint64) and neighbors (int32)
|
||||||
|
cnt, data = read_vector_raw(f_in, "Q")
|
||||||
|
_write_vector_raw(f_out, cnt, data)
|
||||||
|
|
||||||
|
cnt, data = read_vector_raw(f_in, "i")
|
||||||
|
_write_vector_raw(f_out, cnt, data)
|
||||||
|
|
||||||
|
# Scalar params
|
||||||
|
entry_point = read_struct(f_in, "<i")
|
||||||
|
max_level = read_struct(f_in, "<i")
|
||||||
|
ef_construction = read_struct(f_in, "<i")
|
||||||
|
ef_search = read_struct(f_in, "<i")
|
||||||
|
dummy_upper_beam = read_struct(f_in, "<i")
|
||||||
|
f_out.write(struct.pack("<i", entry_point))
|
||||||
|
f_out.write(struct.pack("<i", max_level))
|
||||||
|
f_out.write(struct.pack("<i", ef_construction))
|
||||||
|
f_out.write(struct.pack("<i", ef_search))
|
||||||
|
f_out.write(struct.pack("<i", dummy_upper_beam))
|
||||||
|
|
||||||
|
# Storage fourcc (if present) — write NULL marker and drop any remaining data
|
||||||
|
try:
|
||||||
|
read_struct(f_in, "<I")
|
||||||
|
# Regardless of original, write NULL
|
||||||
|
f_out.write(struct.pack("<I", NULL_INDEX_FOURCC))
|
||||||
|
# Discard the rest of the file (embedding payload)
|
||||||
|
# (Do not copy anything else)
|
||||||
|
except EOFError:
|
||||||
|
# No storage section; nothing else to write
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
# Best-effort cleanup
|
||||||
|
try:
|
||||||
|
if out_path.exists():
|
||||||
|
out_path.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def prune_embeddings_preserve_graph_inplace(index_file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Convenience wrapper: write pruned file to a temporary path next to the
|
||||||
|
original, then atomically replace on success.
|
||||||
|
"""
|
||||||
|
print(f"Pruning embeddings from {index_file_path} to {index_file_path}")
|
||||||
|
print("--------------------------------")
|
||||||
|
# running in mode is-recompute=True and is-compact=False
|
||||||
|
src = Path(index_file_path)
|
||||||
|
tmp = src.with_suffix(".pruned.tmp")
|
||||||
|
ok = prune_embeddings_preserve_graph(str(src), str(tmp))
|
||||||
|
if not ok:
|
||||||
|
if tmp.exists():
|
||||||
|
try:
|
||||||
|
tmp.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
os.replace(str(tmp), str(src))
|
||||||
|
except Exception:
|
||||||
|
# Rollback on failure
|
||||||
|
try:
|
||||||
|
if tmp.exists():
|
||||||
|
tmp.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
return True
|
||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.3.2",
|
"leann-core==0.3.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 4a2c0d67d3...ea86d06ceb
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -5,19 +5,24 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
|
from .embedding_server_manager import EmbeddingServerManager
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .metadata_filter import MetadataFilterEngine
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -115,6 +120,20 @@ class SearchResult:
|
|||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IncrementalAddContext:
|
||||||
|
"""Prepared context for safe incremental add operations on an index."""
|
||||||
|
|
||||||
|
index_path: str
|
||||||
|
passages_file: Path
|
||||||
|
offsets_file: Path
|
||||||
|
vector_index_file: Path
|
||||||
|
embedding_model: str
|
||||||
|
embedding_mode: str
|
||||||
|
distance_metric: str
|
||||||
|
prepared_texts: Optional[list[str]] = None
|
||||||
|
|
||||||
|
|
||||||
class PassageManager:
|
class PassageManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||||
@@ -125,6 +144,7 @@ class PassageManager:
|
|||||||
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
||||||
# per-shard maps and do a lightweight per-shard lookup on demand.
|
# per-shard maps and do a lightweight per-shard lookup on demand.
|
||||||
self._total_count: int = 0
|
self._total_count: int = 0
|
||||||
|
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
||||||
|
|
||||||
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||||
index_name_base = None
|
index_name_base = None
|
||||||
@@ -212,6 +232,56 @@ class PassageManager:
|
|||||||
continue
|
continue
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
|
def filter_search_results(
|
||||||
|
self,
|
||||||
|
search_results: list[SearchResult],
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of SearchResult objects
|
||||||
|
metadata_filters: Filter specifications to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of SearchResult objects
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
||||||
|
|
||||||
|
# Convert SearchResult objects to dictionaries for the filter engine
|
||||||
|
result_dicts = []
|
||||||
|
for result in search_results:
|
||||||
|
result_dicts.append(
|
||||||
|
{
|
||||||
|
"id": result.id,
|
||||||
|
"score": result.score,
|
||||||
|
"text": result.text,
|
||||||
|
"metadata": result.metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters using the filter engine
|
||||||
|
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
||||||
|
|
||||||
|
# Convert back to SearchResult objects
|
||||||
|
filtered_results = []
|
||||||
|
for result_dict in filtered_dicts:
|
||||||
|
filtered_results.append(
|
||||||
|
SearchResult(
|
||||||
|
id=result_dict["id"],
|
||||||
|
score=result_dict["score"],
|
||||||
|
text=result_dict["text"],
|
||||||
|
metadata=result_dict["metadata"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return self._total_count
|
return self._total_count
|
||||||
|
|
||||||
@@ -422,9 +492,7 @@ class LeannBuilder:
|
|||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
meta_data["is_compact"] = is_compact
|
meta_data["is_compact"] = is_compact
|
||||||
meta_data["is_pruned"] = (
|
meta_data["is_pruned"] = is_recompute # Pruned only if compact and recompute
|
||||||
is_compact and is_recompute
|
|
||||||
) # Pruned only if compact and recompute
|
|
||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
@@ -578,6 +646,8 @@ class LeannSearcher:
|
|||||||
self.passage_manager = PassageManager(
|
self.passage_manager = PassageManager(
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
)
|
)
|
||||||
|
# Preserve backend name for conditional parameter forwarding
|
||||||
|
self.backend_name = backend_name
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
@@ -597,11 +667,43 @@ class LeannSearcher:
|
|||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors with optional metadata filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text query to search for
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||||
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
Supported operators:
|
||||||
|
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
||||||
|
- Membership: "in", "not_in"
|
||||||
|
- String: "contains", "starts_with", "ends_with"
|
||||||
|
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
||||||
|
**kwargs: Backend-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with text, metadata, and similarity scores
|
||||||
|
"""
|
||||||
|
# Handle grep search
|
||||||
|
if use_grep:
|
||||||
|
return self._grep_search(query, top_k)
|
||||||
|
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
|
logger.info(f" Metadata filters: {metadata_filters}")
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
# Smart top_k detection and adjustment
|
||||||
@@ -636,23 +738,33 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
# time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
backend_search_kwargs: dict[str, Any] = {
|
||||||
|
"complexity": complexity,
|
||||||
|
"beam_width": beam_width,
|
||||||
|
"prune_ratio": prune_ratio,
|
||||||
|
"recompute_embeddings": recompute_embeddings,
|
||||||
|
"pruning_strategy": pruning_strategy,
|
||||||
|
"zmq_port": zmq_port,
|
||||||
|
}
|
||||||
|
# Only HNSW supports batching; forward conditionally
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
backend_search_kwargs["batch_size"] = batch_size
|
||||||
|
|
||||||
|
# Merge any extra kwargs last
|
||||||
|
backend_search_kwargs.update(kwargs)
|
||||||
|
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
top_k,
|
top_k,
|
||||||
complexity=complexity,
|
**backend_search_kwargs,
|
||||||
beam_width=beam_width,
|
|
||||||
prune_ratio=prune_ratio,
|
|
||||||
recompute_embeddings=recompute_embeddings,
|
|
||||||
pruning_strategy=pruning_strategy,
|
|
||||||
zmq_port=zmq_port,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
search_time = time.time() - start_time
|
||||||
|
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
@@ -691,15 +803,109 @@ class LeannSearcher:
|
|||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply metadata filters if specified
|
||||||
|
if metadata_filters:
|
||||||
|
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
||||||
|
enriched_results = self.passage_manager.filter_search_results(
|
||||||
|
enriched_results, metadata_filters
|
||||||
|
)
|
||||||
|
|
||||||
# Define color codes outside the loop for final message
|
# Define color codes outside the loop for final message
|
||||||
GREEN = "\033[92m"
|
GREEN = "\033[92m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
def _find_jsonl_file(self) -> Optional[str]:
|
||||||
|
"""Find the .jsonl file containing raw passages for grep search"""
|
||||||
|
index_path = Path(self.meta_path_str).parent
|
||||||
|
potential_files = [
|
||||||
|
index_path / "documents.leann.passages.jsonl",
|
||||||
|
index_path.parent / "documents.leann.passages.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_path in potential_files:
|
||||||
|
if file_path.exists():
|
||||||
|
return str(file_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Perform grep-based search on raw passages"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl passages file found for grep search")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = ["grep", "-i", "-n", query, jsonl_file]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||||
|
|
||||||
|
if result.returncode == 1:
|
||||||
|
return []
|
||||||
|
elif result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Grep failed: {result.stderr}")
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for line in result.stdout.strip().split("\n"):
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(":", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(parts[1])
|
||||||
|
text = data.get("text", "")
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", parts[0]),
|
||||||
|
text=text,
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(score),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"grep command not found. Please install grep or use semantic search."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Fallback regex search"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl file found")
|
||||||
|
|
||||||
|
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
with open(jsonl_file, encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if pattern.search(line):
|
||||||
|
try:
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", str(line_num)),
|
||||||
|
text=data.get("text", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(len(pattern.findall(data.get("text", "")))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Explicitly cleanup embedding server resources.
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
This method should be called after you're done using the searcher,
|
This method should be called after you're done using the searcher,
|
||||||
especially in test environments or batch processing scenarios.
|
especially in test environments or batch processing scenarios.
|
||||||
"""
|
"""
|
||||||
@@ -731,9 +937,15 @@ class LeannChat:
|
|||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: Optional[dict[str, Any]] = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
|
searcher: Optional[LeannSearcher] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
if searcher is None:
|
||||||
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
|
self._owns_searcher = True
|
||||||
|
else:
|
||||||
|
self.searcher = searcher
|
||||||
|
self._owns_searcher = False
|
||||||
self.llm = get_llm(llm_config)
|
self.llm = get_llm(llm_config)
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
@@ -747,6 +959,9 @@ class LeannChat:
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
@@ -761,10 +976,12 @@ class LeannChat:
|
|||||||
recompute_embeddings=recompute_embeddings,
|
recompute_embeddings=recompute_embeddings,
|
||||||
pruning_strategy=pruning_strategy,
|
pruning_strategy=pruning_strategy,
|
||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
|
metadata_filters=metadata_filters,
|
||||||
|
batch_size=batch_size,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
search_time = time.time() - search_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -800,7 +1017,9 @@ class LeannChat:
|
|||||||
This method should be called after you're done using the chat interface,
|
This method should be called after you're done using the chat interface,
|
||||||
especially in test environments or batch processing scenarios.
|
especially in test environments or batch processing scenarios.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.searcher, "cleanup"):
|
# Only stop the embedding server if this LeannChat instance created the searcher.
|
||||||
|
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
|
||||||
|
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
|
||||||
self.searcher.cleanup()
|
self.searcher.cleanup()
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
# Enable automatic cleanup patterns
|
||||||
@@ -813,8 +1032,405 @@ class LeannChat:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
|
# ------------------------------
|
||||||
|
# Incremental Add Utilities (HNSW no-recompute only)
|
||||||
|
# ------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_index_paths(index_path: str) -> tuple[Path, Path, Path]:
|
||||||
|
"""Given base index path (without extension), return (passages.jsonl, passages.idx, vector.index).
|
||||||
|
|
||||||
|
For HNSW, vector index file is typically <stem>.index (e.g., documents.index) even when base is
|
||||||
|
'documents.leann'. We prefer an existing <stem>.index, otherwise fall back to <name>.index.
|
||||||
|
"""
|
||||||
|
base = Path(index_path)
|
||||||
|
passages_file = base.parent / f"{base.name}.passages.jsonl"
|
||||||
|
offsets_file = base.parent / f"{base.name}.passages.idx"
|
||||||
|
candidate_name_index = base.parent / f"{base.name}.index"
|
||||||
|
candidate_stem_index = base.parent / f"{base.stem}.index"
|
||||||
|
vector_index_file = (
|
||||||
|
candidate_stem_index if candidate_stem_index.exists() else candidate_name_index
|
||||||
|
)
|
||||||
|
return passages_file, offsets_file, vector_index_file
|
||||||
|
|
||||||
|
|
||||||
|
def _read_meta_file(index_path: str) -> dict[str, Any]:
|
||||||
|
meta_path = Path(f"{index_path}.meta.json")
|
||||||
|
if not meta_path.exists():
|
||||||
|
raise FileNotFoundError(f"Leann metadata file not found: {meta_path}")
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_offset_map_pickle(offsets_file: Path) -> dict[str, int]:
|
||||||
|
if not offsets_file.exists():
|
||||||
|
return {}
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
return pickle.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_passages_and_update_offsets(
|
||||||
|
passages_file: Path, offsets_file: Path, new_texts: list[str]
|
||||||
|
) -> list[str]:
|
||||||
|
"""Append new texts to passages file, update offset map, and return assigned string IDs.
|
||||||
|
|
||||||
|
IDs are assigned as incrementing integers based on existing keys in the offset map.
|
||||||
|
"""
|
||||||
|
offset_map = _load_offset_map_pickle(offsets_file)
|
||||||
|
# Compute next numeric id
|
||||||
|
numeric_ids = [int(x) for x in offset_map.keys() if str(x).isdigit()]
|
||||||
|
next_id_num = (max(numeric_ids) + 1) if numeric_ids else 0
|
||||||
|
assigned_ids: list[str] = []
|
||||||
|
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for text in new_texts:
|
||||||
|
offset = f.tell()
|
||||||
|
str_id = str(next_id_num)
|
||||||
|
json.dump({"id": str_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[str_id] = offset
|
||||||
|
assigned_ids.append(str_id)
|
||||||
|
next_id_num += 1
|
||||||
|
|
||||||
|
with open(offsets_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
return assigned_ids
|
||||||
|
|
||||||
|
|
||||||
|
def incremental_add_texts(
|
||||||
|
index_path: str,
|
||||||
|
texts: list[str],
|
||||||
|
*,
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embedding_mode: Optional[str] = None,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Incrementally add text chunks to an existing HNSW index built with no-recompute.
|
||||||
|
|
||||||
|
- Validates backend is HNSW and index is non-compact (no-recompute path)
|
||||||
|
- Appends passages and offsets
|
||||||
|
- Computes embeddings and appends to the HNSW vector index
|
||||||
|
|
||||||
|
Returns number of added chunks.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
meta = _read_meta_file(index_path)
|
||||||
|
if meta.get("backend_name") != "hnsw":
|
||||||
|
raise RuntimeError("Incremental add is currently supported only for HNSW backend")
|
||||||
|
if meta.get("is_compact", True):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Index is compact/pruned. Rebuild base with is_recompute=False and is_compact=False for incremental add."
|
||||||
|
)
|
||||||
|
|
||||||
|
passages_file, offsets_file, vector_index_file = _resolve_index_paths(index_path)
|
||||||
|
if not vector_index_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Vector index file missing: {vector_index_file}. Build base first with LeannBuilder."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resolve embedding config from meta if not provided
|
||||||
|
model_name = embedding_model or meta.get("embedding_model", "facebook/contriever")
|
||||||
|
mode_name = embedding_mode or meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
|
||||||
|
# Append passages and update offsets
|
||||||
|
assigned_ids = _append_passages_and_update_offsets(passages_file, offsets_file, texts)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
# Embedding computation path
|
||||||
|
esm = None
|
||||||
|
port = None
|
||||||
|
if recompute:
|
||||||
|
# Determine distance metric early for server config
|
||||||
|
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||||
|
# Start embedding server and compute via ZMQ for consistency with recompute semantics
|
||||||
|
passages_source_file = f"{index_path}.meta.json"
|
||||||
|
esm = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
|
||||||
|
)
|
||||||
|
started, port = esm.start_server(
|
||||||
|
port=5557,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=mode_name,
|
||||||
|
passages_file=passages_source_file,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
enable_warmup=False,
|
||||||
|
)
|
||||||
|
if not started:
|
||||||
|
raise RuntimeError("Failed to start embedding server for recompute add")
|
||||||
|
embeddings = compute_embeddings_via_server(texts, model_name, port)
|
||||||
|
else:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name=model_name,
|
||||||
|
mode=mode_name,
|
||||||
|
use_server=False,
|
||||||
|
is_build=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize for cosine if needed
|
||||||
|
if "distance_metric" not in locals():
|
||||||
|
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
# Append via backend helper (supports ef_construction/recompute plumbing)
|
||||||
|
try:
|
||||||
|
from leann_backend_hnsw.hnsw_backend import add_vectors as hnsw_add_vectors # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to import HNSW backend add helper. Ensure HNSW backend is installed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Propagate ZMQ port to FAISS add path when recompute is True
|
||||||
|
if recompute and port is not None:
|
||||||
|
os.environ["LEANN_ZMQ_PORT"] = str(port)
|
||||||
|
|
||||||
|
hnsw_add_vectors(
|
||||||
|
str(vector_index_file),
|
||||||
|
embeddings,
|
||||||
|
ef_construction=ef_construction,
|
||||||
|
recompute=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop server after add when recompute path used
|
||||||
|
if esm is not None:
|
||||||
try:
|
try:
|
||||||
self.cleanup()
|
esm.stop_server()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Sanity: ids length should match embeddings rows
|
||||||
|
if len(assigned_ids) != embeddings.shape[0]:
|
||||||
|
warnings.warn(
|
||||||
|
f"Assigned {len(assigned_ids)} IDs but computed {embeddings.shape[0]} embeddings.",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(assigned_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def create_incremental_add_context(
|
||||||
|
index_path: str,
|
||||||
|
*,
|
||||||
|
# Optional embedding choices; if None will use meta
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embedding_mode: Optional[str] = None,
|
||||||
|
# Optional data-to-text preparation in context
|
||||||
|
data_dir: Optional[str] = None,
|
||||||
|
required_exts: Optional[list[str]] = None,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
max_items: int = -1,
|
||||||
|
) -> IncrementalAddContext:
|
||||||
|
"""Validate index and prepare context for repeated incremental adds.
|
||||||
|
|
||||||
|
Additionally, if data_dir is provided, this function will load documents,
|
||||||
|
chunk them to texts with the specified parameters, and store them in ctx.prepared_texts.
|
||||||
|
"""
|
||||||
|
meta = _read_meta_file(index_path)
|
||||||
|
if meta.get("backend_name") != "hnsw":
|
||||||
|
raise RuntimeError("Incremental add is currently supported only for HNSW backend")
|
||||||
|
if meta.get("is_compact", True):
|
||||||
|
raise RuntimeError(
|
||||||
|
"Index is compact/pruned. Rebuild base with is_recompute=False and is_compact=False for incremental add."
|
||||||
|
)
|
||||||
|
|
||||||
|
passages_file, offsets_file, vector_index_file = _resolve_index_paths(index_path)
|
||||||
|
if not vector_index_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Vector index file missing: {vector_index_file}. Build base first with LeannBuilder."
|
||||||
|
)
|
||||||
|
|
||||||
|
model_name = embedding_model or meta.get("embedding_model", "facebook/contriever")
|
||||||
|
mode_name = embedding_mode or meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
distance_metric = meta.get("backend_kwargs", {}).get("distance_metric", "mips").lower()
|
||||||
|
|
||||||
|
prepared_texts: Optional[list[str]] = None
|
||||||
|
if data_dir is not None:
|
||||||
|
try:
|
||||||
|
from llama_index.core import SimpleDirectoryReader # type: ignore
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"llama-index-core is required when using data_dir in create_incremental_add_context"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
|
||||||
|
if required_exts:
|
||||||
|
reader_kwargs["required_exts"] = required_exts
|
||||||
|
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
|
||||||
|
if documents:
|
||||||
|
splitter = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
prepared_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
nodes = splitter.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
prepared_texts.extend([node.get_content() for node in nodes])
|
||||||
|
except Exception:
|
||||||
|
content = doc.get_content()
|
||||||
|
if content and content.strip():
|
||||||
|
prepared_texts.append(content.strip())
|
||||||
|
if max_items > 0 and len(prepared_texts) > max_items:
|
||||||
|
prepared_texts = prepared_texts[:max_items]
|
||||||
|
|
||||||
|
return IncrementalAddContext(
|
||||||
|
index_path=index_path,
|
||||||
|
passages_file=passages_file,
|
||||||
|
offsets_file=offsets_file,
|
||||||
|
vector_index_file=vector_index_file,
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=mode_name,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
prepared_texts=prepared_texts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def incremental_add_texts_with_context(
|
||||||
|
ctx: IncrementalAddContext,
|
||||||
|
texts: list[str],
|
||||||
|
*,
|
||||||
|
ef_construction: Optional[int] = None,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> int:
|
||||||
|
"""Incrementally add texts using a prepared context (no repeated validation).
|
||||||
|
|
||||||
|
For non-compact HNSW, ef_construction (efConstruction) can be overridden during insertion.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Append passages & offsets
|
||||||
|
_append_passages_and_update_offsets(ctx.passages_file, ctx.offsets_file, texts)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
# Embedding computation path
|
||||||
|
esm = None
|
||||||
|
port = None
|
||||||
|
if recompute:
|
||||||
|
passages_source_file = f"{ctx.index_path}.meta.json"
|
||||||
|
esm = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server",
|
||||||
|
)
|
||||||
|
started, port = esm.start_server(
|
||||||
|
port=5557,
|
||||||
|
model_name=ctx.embedding_model,
|
||||||
|
embedding_mode=ctx.embedding_mode,
|
||||||
|
passages_file=passages_source_file,
|
||||||
|
distance_metric=ctx.distance_metric,
|
||||||
|
enable_warmup=False,
|
||||||
|
)
|
||||||
|
if not started:
|
||||||
|
raise RuntimeError("Failed to start embedding server for recompute add")
|
||||||
|
embeddings = compute_embeddings_via_server(texts, ctx.embedding_model, port)
|
||||||
|
else:
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name=ctx.embedding_model,
|
||||||
|
mode=ctx.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
is_build=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize for cosine if needed
|
||||||
|
if ctx.distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
# Append via backend helper (supports ef_construction/recompute plumbing)
|
||||||
|
try:
|
||||||
|
from leann_backend_hnsw.hnsw_backend import add_vectors as hnsw_add_vectors # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to import HNSW backend add helper. Ensure HNSW backend is installed."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if recompute and port is not None:
|
||||||
|
os.environ["LEANN_ZMQ_PORT"] = str(port)
|
||||||
|
|
||||||
|
hnsw_add_vectors(
|
||||||
|
str(ctx.vector_index_file),
|
||||||
|
embeddings,
|
||||||
|
ef_construction=ef_construction,
|
||||||
|
recompute=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop server after add when recompute path used
|
||||||
|
if esm is not None:
|
||||||
|
try:
|
||||||
|
esm.stop_server()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return embeddings.shape[0]
|
||||||
|
|
||||||
|
|
||||||
|
def incremental_add_directory(
|
||||||
|
index_path: str,
|
||||||
|
data_dir: str,
|
||||||
|
*,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
required_exts: Optional[list[str]] = None,
|
||||||
|
max_items: int = -1,
|
||||||
|
embedding_model: Optional[str] = None,
|
||||||
|
embedding_mode: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Load documents from a directory, chunk them, and incrementally add to an index.
|
||||||
|
|
||||||
|
Chunking uses LlamaIndex SentenceSplitter for simplicity and avoids external app dependencies.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from llama_index.core import SimpleDirectoryReader # type: ignore
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError("llama-index-core is required for incremental_add_directory") from e
|
||||||
|
|
||||||
|
reader_kwargs: dict[str, Any] = {"recursive": True, "encoding": "utf-8"}
|
||||||
|
if required_exts:
|
||||||
|
reader_kwargs["required_exts"] = required_exts
|
||||||
|
documents = SimpleDirectoryReader(data_dir, **reader_kwargs).load_data(show_progress=True)
|
||||||
|
if not documents:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Traditional text chunking
|
||||||
|
splitter = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
all_texts: list[str] = []
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
nodes = splitter.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend([node.get_content() for node in nodes])
|
||||||
|
except Exception:
|
||||||
|
content = doc.get_content()
|
||||||
|
if content and content.strip():
|
||||||
|
all_texts.append(content.strip())
|
||||||
|
|
||||||
|
if max_items > 0 and len(all_texts) > max_items:
|
||||||
|
all_texts = all_texts[:max_items]
|
||||||
|
|
||||||
|
return incremental_add_texts(
|
||||||
|
index_path,
|
||||||
|
all_texts,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Enhanced chunking utilities with AST-aware code chunking support.
|
Enhanced chunking utilities with AST-aware code chunking support.
|
||||||
Provides unified interface for both traditional and AST-based text chunking.
|
Packaged within leann-core so installed wheels can import it reliably.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -22,30 +22,9 @@ CODE_EXTENSIONS = {
|
|||||||
".jsx": "typescript",
|
".jsx": "typescript",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Default chunk parameters for different content types
|
|
||||||
DEFAULT_CHUNK_PARAMS = {
|
|
||||||
"code": {
|
|
||||||
"max_chunk_size": 512,
|
|
||||||
"chunk_overlap": 64,
|
|
||||||
},
|
|
||||||
"text": {
|
|
||||||
"chunk_size": 256,
|
|
||||||
"chunk_overlap": 128,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||||
"""
|
"""Separate documents into code files and regular text files."""
|
||||||
Separate documents into code files and regular text files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of LlamaIndex Document objects
|
|
||||||
code_extensions: Dict mapping file extensions to languages (defaults to CODE_EXTENSIONS)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (code_documents, text_documents)
|
|
||||||
"""
|
|
||||||
if code_extensions is None:
|
if code_extensions is None:
|
||||||
code_extensions = CODE_EXTENSIONS
|
code_extensions = CODE_EXTENSIONS
|
||||||
|
|
||||||
@@ -53,16 +32,10 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
|||||||
text_docs = []
|
text_docs = []
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
# Get file path from metadata
|
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
|
||||||
file_path = doc.metadata.get("file_path", "")
|
|
||||||
if not file_path:
|
|
||||||
# Fallback to file_name
|
|
||||||
file_path = doc.metadata.get("file_name", "")
|
|
||||||
|
|
||||||
if file_path:
|
if file_path:
|
||||||
file_ext = Path(file_path).suffix.lower()
|
file_ext = Path(file_path).suffix.lower()
|
||||||
if file_ext in code_extensions:
|
if file_ext in code_extensions:
|
||||||
# Add language info to metadata
|
|
||||||
doc.metadata["language"] = code_extensions[file_ext]
|
doc.metadata["language"] = code_extensions[file_ext]
|
||||||
doc.metadata["is_code"] = True
|
doc.metadata["is_code"] = True
|
||||||
code_docs.append(doc)
|
code_docs.append(doc)
|
||||||
@@ -70,7 +43,6 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
|||||||
doc.metadata["is_code"] = False
|
doc.metadata["is_code"] = False
|
||||||
text_docs.append(doc)
|
text_docs.append(doc)
|
||||||
else:
|
else:
|
||||||
# If no file path, treat as text
|
|
||||||
doc.metadata["is_code"] = False
|
doc.metadata["is_code"] = False
|
||||||
text_docs.append(doc)
|
text_docs.append(doc)
|
||||||
|
|
||||||
@@ -79,7 +51,7 @@ def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
|||||||
|
|
||||||
|
|
||||||
def get_language_from_extension(file_path: str) -> Optional[str]:
|
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||||
"""Get the programming language from file extension."""
|
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
|
||||||
ext = Path(file_path).suffix.lower()
|
ext = Path(file_path).suffix.lower()
|
||||||
return CODE_EXTENSIONS.get(ext)
|
return CODE_EXTENSIONS.get(ext)
|
||||||
|
|
||||||
@@ -90,40 +62,26 @@ def create_ast_chunks(
|
|||||||
chunk_overlap: int = 64,
|
chunk_overlap: int = 64,
|
||||||
metadata_template: str = "default",
|
metadata_template: str = "default",
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""Create AST-aware chunks from code documents using astchunk.
|
||||||
Create AST-aware chunks from code documents using astchunk.
|
|
||||||
|
|
||||||
Args:
|
Falls back to traditional chunking if astchunk is unavailable.
|
||||||
documents: List of code documents
|
|
||||||
max_chunk_size: Maximum characters per chunk
|
|
||||||
chunk_overlap: Number of AST nodes to overlap between chunks
|
|
||||||
metadata_template: Template for chunk metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks with preserved code structure
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from astchunk import ASTChunkBuilder
|
from astchunk import ASTChunkBuilder # optional dependency
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(f"astchunk not available: {e}")
|
logger.error(f"astchunk not available: {e}")
|
||||||
logger.info("Falling back to traditional chunking for code files")
|
logger.info("Falling back to traditional chunking for code files")
|
||||||
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||||
|
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
# Get language from metadata (set by detect_code_files)
|
|
||||||
language = doc.metadata.get("language")
|
language = doc.metadata.get("language")
|
||||||
if not language:
|
if not language:
|
||||||
logger.warning(
|
logger.warning("No language detected; falling back to traditional chunking")
|
||||||
"No language detected for document, falling back to traditional chunking"
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
)
|
|
||||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
|
||||||
all_chunks.extend(traditional_chunks)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Configure astchunk
|
|
||||||
configs = {
|
configs = {
|
||||||
"max_chunk_size": max_chunk_size,
|
"max_chunk_size": max_chunk_size,
|
||||||
"language": language,
|
"language": language,
|
||||||
@@ -131,7 +89,6 @@ def create_ast_chunks(
|
|||||||
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add repository-level metadata if available
|
|
||||||
repo_metadata = {
|
repo_metadata = {
|
||||||
"file_path": doc.metadata.get("file_path", ""),
|
"file_path": doc.metadata.get("file_path", ""),
|
||||||
"file_name": doc.metadata.get("file_name", ""),
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
@@ -140,17 +97,13 @@ def create_ast_chunks(
|
|||||||
}
|
}
|
||||||
configs["repo_level_metadata"] = repo_metadata
|
configs["repo_level_metadata"] = repo_metadata
|
||||||
|
|
||||||
# Create chunk builder and process
|
|
||||||
chunk_builder = ASTChunkBuilder(**configs)
|
chunk_builder = ASTChunkBuilder(**configs)
|
||||||
code_content = doc.get_content()
|
code_content = doc.get_content()
|
||||||
|
|
||||||
if not code_content or not code_content.strip():
|
if not code_content or not code_content.strip():
|
||||||
logger.warning("Empty code content, skipping")
|
logger.warning("Empty code content, skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunks = chunk_builder.chunkify(code_content)
|
chunks = chunk_builder.chunkify(code_content)
|
||||||
|
|
||||||
# Extract text content from chunks
|
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
if hasattr(chunk, "text"):
|
if hasattr(chunk, "text"):
|
||||||
chunk_text = chunk.text
|
chunk_text = chunk.text
|
||||||
@@ -159,7 +112,6 @@ def create_ast_chunks(
|
|||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
chunk_text = chunk
|
chunk_text = chunk
|
||||||
else:
|
else:
|
||||||
# Try to convert to string
|
|
||||||
chunk_text = str(chunk)
|
chunk_text = str(chunk)
|
||||||
|
|
||||||
if chunk_text and chunk_text.strip():
|
if chunk_text and chunk_text.strip():
|
||||||
@@ -168,12 +120,10 @@ def create_ast_chunks(
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"AST chunking failed for {language} file: {e}")
|
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||||
logger.info("Falling back to traditional chunking")
|
logger.info("Falling back to traditional chunking")
|
||||||
traditional_chunks = create_traditional_chunks([doc], max_chunk_size, chunk_overlap)
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
all_chunks.extend(traditional_chunks)
|
|
||||||
|
|
||||||
return all_chunks
|
return all_chunks
|
||||||
|
|
||||||
@@ -181,23 +131,10 @@ def create_ast_chunks(
|
|||||||
def create_traditional_chunks(
|
def create_traditional_chunks(
|
||||||
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||||
Create traditional text chunks using LlamaIndex SentenceSplitter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of documents to chunk
|
|
||||||
chunk_size: Size of each chunk in characters
|
|
||||||
chunk_overlap: Overlap between chunks
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks
|
|
||||||
"""
|
|
||||||
# Handle invalid chunk_size values
|
|
||||||
if chunk_size <= 0:
|
if chunk_size <= 0:
|
||||||
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||||
chunk_size = 256
|
chunk_size = 256
|
||||||
|
|
||||||
# Ensure chunk_overlap is not negative and not larger than chunk_size
|
|
||||||
if chunk_overlap < 0:
|
if chunk_overlap < 0:
|
||||||
chunk_overlap = 0
|
chunk_overlap = 0
|
||||||
if chunk_overlap >= chunk_size:
|
if chunk_overlap >= chunk_size:
|
||||||
@@ -215,12 +152,9 @@ def create_traditional_chunks(
|
|||||||
try:
|
try:
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
if nodes:
|
if nodes:
|
||||||
chunk_texts = [node.get_content() for node in nodes]
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
all_texts.extend(chunk_texts)
|
|
||||||
logger.debug(f"Created {len(chunk_texts)} traditional chunks from document")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Traditional chunking failed for document: {e}")
|
logger.error(f"Traditional chunking failed for document: {e}")
|
||||||
# As last resort, add the raw content
|
|
||||||
content = doc.get_content()
|
content = doc.get_content()
|
||||||
if content and content.strip():
|
if content and content.strip():
|
||||||
all_texts.append(content.strip())
|
all_texts.append(content.strip())
|
||||||
@@ -238,32 +172,13 @@ def create_text_chunks(
|
|||||||
code_file_extensions: Optional[list[str]] = None,
|
code_file_extensions: Optional[list[str]] = None,
|
||||||
ast_fallback_traditional: bool = True,
|
ast_fallback_traditional: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
"""Create text chunks from documents with optional AST support for code files."""
|
||||||
Create text chunks from documents with optional AST support for code files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
documents: List of LlamaIndex Document objects
|
|
||||||
chunk_size: Size for traditional text chunks
|
|
||||||
chunk_overlap: Overlap for traditional text chunks
|
|
||||||
use_ast_chunking: Whether to use AST chunking for code files
|
|
||||||
ast_chunk_size: Size for AST chunks
|
|
||||||
ast_chunk_overlap: Overlap for AST chunks
|
|
||||||
code_file_extensions: Custom list of code file extensions
|
|
||||||
ast_fallback_traditional: Fall back to traditional chunking on AST errors
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of text chunks
|
|
||||||
"""
|
|
||||||
if not documents:
|
if not documents:
|
||||||
logger.warning("No documents provided for chunking")
|
logger.warning("No documents provided for chunking")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Create a local copy of supported extensions for this function call
|
|
||||||
local_code_extensions = CODE_EXTENSIONS.copy()
|
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||||
|
|
||||||
# Update supported extensions if provided
|
|
||||||
if code_file_extensions:
|
if code_file_extensions:
|
||||||
# Map extensions to languages (simplified mapping)
|
|
||||||
ext_mapping = {
|
ext_mapping = {
|
||||||
".py": "python",
|
".py": "python",
|
||||||
".java": "java",
|
".java": "java",
|
||||||
@@ -273,47 +188,32 @@ def create_text_chunks(
|
|||||||
}
|
}
|
||||||
for ext in code_file_extensions:
|
for ext in code_file_extensions:
|
||||||
if ext.lower() not in local_code_extensions:
|
if ext.lower() not in local_code_extensions:
|
||||||
# Try to guess language from extension
|
|
||||||
if ext.lower() in ext_mapping:
|
if ext.lower() in ext_mapping:
|
||||||
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||||
|
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
|
|
||||||
if use_ast_chunking:
|
if use_ast_chunking:
|
||||||
# Separate code and text documents using local extensions
|
|
||||||
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||||
|
|
||||||
# Process code files with AST chunking
|
|
||||||
if code_docs:
|
if code_docs:
|
||||||
logger.info(f"Processing {len(code_docs)} code files with AST chunking")
|
|
||||||
try:
|
try:
|
||||||
ast_chunks = create_ast_chunks(
|
all_chunks.extend(
|
||||||
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
create_ast_chunks(
|
||||||
|
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||||
|
)
|
||||||
)
|
)
|
||||||
all_chunks.extend(ast_chunks)
|
|
||||||
logger.info(f"Created {len(ast_chunks)} AST chunks from code files")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"AST chunking failed: {e}")
|
logger.error(f"AST chunking failed: {e}")
|
||||||
if ast_fallback_traditional:
|
if ast_fallback_traditional:
|
||||||
logger.info("Falling back to traditional chunking for code files")
|
all_chunks.extend(
|
||||||
traditional_code_chunks = create_traditional_chunks(
|
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||||
code_docs, chunk_size, chunk_overlap
|
|
||||||
)
|
)
|
||||||
all_chunks.extend(traditional_code_chunks)
|
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Process text files with traditional chunking
|
|
||||||
if text_docs:
|
if text_docs:
|
||||||
logger.info(f"Processing {len(text_docs)} text files with traditional chunking")
|
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||||
text_chunks = create_traditional_chunks(text_docs, chunk_size, chunk_overlap)
|
|
||||||
all_chunks.extend(text_chunks)
|
|
||||||
logger.info(f"Created {len(text_chunks)} traditional chunks from text files")
|
|
||||||
else:
|
else:
|
||||||
# Use traditional chunking for all files
|
|
||||||
logger.info(f"Processing {len(documents)} documents with traditional chunking")
|
|
||||||
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
logger.info(f"Total chunks created: {len(all_chunks)}")
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
@@ -322,9 +321,17 @@ Examples:
|
|||||||
|
|
||||||
return basic_matches
|
return basic_matches
|
||||||
|
|
||||||
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
|
def _should_exclude_file(self, file_path: Path, gitignore_matches) -> bool:
|
||||||
"""Check if a file should be excluded using gitignore parser."""
|
"""Check if a file should be excluded using gitignore parser.
|
||||||
return gitignore_matches(str(relative_path))
|
|
||||||
|
Always match against absolute, posix-style paths for consistency with
|
||||||
|
gitignore_parser expectations.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
absolute_path = file_path.resolve()
|
||||||
|
except Exception:
|
||||||
|
absolute_path = Path(str(file_path))
|
||||||
|
return gitignore_matches(absolute_path.as_posix())
|
||||||
|
|
||||||
def _is_git_submodule(self, path: Path) -> bool:
|
def _is_git_submodule(self, path: Path) -> bool:
|
||||||
"""Check if a path is a git submodule."""
|
"""Check if a path is a git submodule."""
|
||||||
@@ -396,7 +403,9 @@ Examples:
|
|||||||
print(f" {current_path}")
|
print(f" {current_path}")
|
||||||
print(" " + "─" * 45)
|
print(" " + "─" * 45)
|
||||||
|
|
||||||
current_indexes = self._discover_indexes_in_project(current_path)
|
current_indexes = self._discover_indexes_in_project(
|
||||||
|
current_path, exclude_dirs=other_projects
|
||||||
|
)
|
||||||
if current_indexes:
|
if current_indexes:
|
||||||
for idx in current_indexes:
|
for idx in current_indexes:
|
||||||
total_indexes += 1
|
total_indexes += 1
|
||||||
@@ -435,9 +444,14 @@ Examples:
|
|||||||
print(" leann build my-docs --docs ./documents")
|
print(" leann build my-docs --docs ./documents")
|
||||||
else:
|
else:
|
||||||
# Count only projects that have at least one discoverable index
|
# Count only projects that have at least one discoverable index
|
||||||
projects_count = sum(
|
projects_count = 0
|
||||||
1 for p in valid_projects if len(self._discover_indexes_in_project(p)) > 0
|
for p in valid_projects:
|
||||||
)
|
if p == current_path:
|
||||||
|
discovered = self._discover_indexes_in_project(p, exclude_dirs=other_projects)
|
||||||
|
else:
|
||||||
|
discovered = self._discover_indexes_in_project(p)
|
||||||
|
if len(discovered) > 0:
|
||||||
|
projects_count += 1
|
||||||
print(f"📊 Total: {total_indexes} indexes across {projects_count} projects")
|
print(f"📊 Total: {total_indexes} indexes across {projects_count} projects")
|
||||||
|
|
||||||
if current_indexes_count > 0:
|
if current_indexes_count > 0:
|
||||||
@@ -454,9 +468,22 @@ Examples:
|
|||||||
print("\n💡 Create your first index:")
|
print("\n💡 Create your first index:")
|
||||||
print(" leann build my-docs --docs ./documents")
|
print(" leann build my-docs --docs ./documents")
|
||||||
|
|
||||||
def _discover_indexes_in_project(self, project_path: Path):
|
def _discover_indexes_in_project(
|
||||||
"""Discover all indexes in a project directory (both CLI and apps formats)"""
|
self, project_path: Path, exclude_dirs: Optional[list[Path]] = None
|
||||||
|
):
|
||||||
|
"""Discover all indexes in a project directory (both CLI and apps formats)
|
||||||
|
|
||||||
|
exclude_dirs: when provided, skip any APP-format index files that are
|
||||||
|
located under these directories. This prevents duplicates when the
|
||||||
|
current project is a parent directory of other registered projects.
|
||||||
|
"""
|
||||||
indexes = []
|
indexes = []
|
||||||
|
exclude_dirs = exclude_dirs or []
|
||||||
|
# normalize to resolved paths once for comparison
|
||||||
|
try:
|
||||||
|
exclude_dirs_resolved = [p.resolve() for p in exclude_dirs]
|
||||||
|
except Exception:
|
||||||
|
exclude_dirs_resolved = exclude_dirs
|
||||||
|
|
||||||
# 1. CLI format: .leann/indexes/index_name/
|
# 1. CLI format: .leann/indexes/index_name/
|
||||||
cli_indexes_dir = project_path / ".leann" / "indexes"
|
cli_indexes_dir = project_path / ".leann" / "indexes"
|
||||||
@@ -495,6 +522,17 @@ Examples:
|
|||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# Skip meta files that live under excluded directories
|
||||||
|
try:
|
||||||
|
meta_parent_resolved = meta_file.parent.resolve()
|
||||||
|
if any(
|
||||||
|
meta_parent_resolved.is_relative_to(ex_dir)
|
||||||
|
for ex_dir in exclude_dirs_resolved
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
# best effort; if resolve or comparison fails, do not exclude
|
||||||
|
pass
|
||||||
# Use the parent directory name as the app index display name
|
# Use the parent directory name as the app index display name
|
||||||
display_name = meta_file.parent.name
|
display_name = meta_file.parent.name
|
||||||
# Extract file base used to store files
|
# Extract file base used to store files
|
||||||
@@ -1022,7 +1060,8 @@ Examples:
|
|||||||
|
|
||||||
# Try to use better PDF parsers first, but only if PDFs are requested
|
# Try to use better PDF parsers first, but only if PDFs are requested
|
||||||
documents = []
|
documents = []
|
||||||
docs_path = Path(docs_dir)
|
# Use resolved absolute paths to avoid mismatches (symlinks, relative vs absolute)
|
||||||
|
docs_path = Path(docs_dir).resolve()
|
||||||
|
|
||||||
# Check if we should process PDFs
|
# Check if we should process PDFs
|
||||||
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
||||||
@@ -1031,10 +1070,15 @@ Examples:
|
|||||||
for file_path in docs_path.rglob("*.pdf"):
|
for file_path in docs_path.rglob("*.pdf"):
|
||||||
# Check if file matches any exclude pattern
|
# Check if file matches any exclude pattern
|
||||||
try:
|
try:
|
||||||
|
# Ensure both paths are resolved before computing relativity
|
||||||
|
file_path_resolved = file_path.resolve()
|
||||||
|
# Determine directory scope using the non-resolved path to avoid
|
||||||
|
# misclassifying symlinked entries as outside the docs directory
|
||||||
relative_path = file_path.relative_to(docs_path)
|
relative_path = file_path.relative_to(docs_path)
|
||||||
if not include_hidden and _path_has_hidden_segment(relative_path):
|
if not include_hidden and _path_has_hidden_segment(relative_path):
|
||||||
continue
|
continue
|
||||||
if self._should_exclude_file(relative_path, gitignore_matches):
|
# Use absolute path for gitignore matching
|
||||||
|
if self._should_exclude_file(file_path_resolved, gitignore_matches):
|
||||||
continue
|
continue
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Skip files that can't be made relative to docs_path
|
# Skip files that can't be made relative to docs_path
|
||||||
@@ -1077,10 +1121,11 @@ Examples:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Return True if file should be included (not excluded)"""
|
"""Return True if file should be included (not excluded)"""
|
||||||
try:
|
try:
|
||||||
docs_path_obj = Path(docs_dir)
|
docs_path_obj = Path(docs_dir).resolve()
|
||||||
file_path_obj = Path(file_path)
|
file_path_obj = Path(file_path).resolve()
|
||||||
relative_path = file_path_obj.relative_to(docs_path_obj)
|
# Use absolute path for gitignore matching
|
||||||
return not self._should_exclude_file(relative_path, gitignore_matches)
|
_ = file_path_obj.relative_to(docs_path_obj) # validate scope
|
||||||
|
return not self._should_exclude_file(file_path_obj, gitignore_matches)
|
||||||
except (ValueError, OSError):
|
except (ValueError, OSError):
|
||||||
return True # Include files that can't be processed
|
return True # Include files that can't be processed
|
||||||
|
|
||||||
@@ -1170,13 +1215,8 @@ Examples:
|
|||||||
if use_ast:
|
if use_ast:
|
||||||
print("🧠 Using AST-aware chunking for code files")
|
print("🧠 Using AST-aware chunking for code files")
|
||||||
try:
|
try:
|
||||||
# Import enhanced chunking utilities
|
# Import enhanced chunking utilities from packaged module
|
||||||
# Add apps directory to path to import chunking utilities
|
from .chunking_utils import create_text_chunks
|
||||||
apps_dir = Path(__file__).parent.parent.parent.parent.parent / "apps"
|
|
||||||
if apps_dir.exists():
|
|
||||||
sys.path.insert(0, str(apps_dir))
|
|
||||||
|
|
||||||
from chunking import create_text_chunks
|
|
||||||
|
|
||||||
# Use enhanced chunking with AST support
|
# Use enhanced chunking with AST support
|
||||||
all_texts = create_text_chunks(
|
all_texts = create_text_chunks(
|
||||||
@@ -1191,7 +1231,9 @@ Examples:
|
|||||||
)
|
)
|
||||||
|
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"⚠️ AST chunking not available ({e}), falling back to traditional chunking")
|
print(
|
||||||
|
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||||
|
)
|
||||||
use_ast = False
|
use_ast = False
|
||||||
|
|
||||||
if not use_ast:
|
if not use_ast:
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -28,6 +29,8 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -50,6 +53,8 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
manual_tokenize=manual_tokenize,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(texts, model_name)
|
||||||
@@ -71,6 +76,8 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -214,20 +221,130 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
logger.info(
|
||||||
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
|
)
|
||||||
|
|
||||||
# Use torch.inference_mode for optimal performance
|
start_time = time.time()
|
||||||
with torch.inference_mode():
|
if not manual_tokenize:
|
||||||
embeddings = model.encode(
|
# Use SentenceTransformer's optimized encode path (default)
|
||||||
texts,
|
with torch.inference_mode():
|
||||||
batch_size=batch_size,
|
embeddings = model.encode(
|
||||||
show_progress_bar=is_build, # Don't show progress bar in server environment
|
texts,
|
||||||
convert_to_numpy=True,
|
batch_size=batch_size,
|
||||||
normalize_embeddings=False,
|
show_progress_bar=is_build, # Don't show progress bar in server environment
|
||||||
device=device,
|
convert_to_numpy=True,
|
||||||
)
|
normalize_embeddings=False,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
# Synchronize if CUDA to measure accurate wall time
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
|
try:
|
||||||
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
|
else:
|
||||||
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
|
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||||
|
hf_model.to(device)
|
||||||
|
hf_model.eval()
|
||||||
|
# Optional compile on supported devices
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
|
all_embeddings: list[np.ndarray] = []
|
||||||
|
# Progress bar when building or for large inputs
|
||||||
|
show_progress = is_build or len(texts) > 32
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm # type: ignore
|
||||||
|
|
||||||
|
batch_iter = tqdm(
|
||||||
|
range(0, len(texts), batch_size),
|
||||||
|
desc="Embedding (manual)",
|
||||||
|
unit="batch",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
except Exception:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
start_time_manual = time.time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
for start_index in batch_iter:
|
||||||
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_index:end_index]
|
||||||
|
tokenize_start_time = time.time()
|
||||||
|
inputs = hf_tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
tokenize_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||||
|
)
|
||||||
|
# Print shapes of all input tensors for debugging
|
||||||
|
for k, v in inputs.items():
|
||||||
|
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||||
|
to_device_start_time = time.time()
|
||||||
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
to_device_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||||
|
)
|
||||||
|
forward_start_time = time.time()
|
||||||
|
outputs = hf_model(**inputs)
|
||||||
|
forward_end_time = time.time()
|
||||||
|
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||||
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
|
masked = last_hidden_state * mask
|
||||||
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
|
||||||
|
end_time = time.time()
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
logger.info(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
|
|||||||
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""
|
||||||
|
Metadata filtering engine for LEANN search results.
|
||||||
|
|
||||||
|
This module provides generic metadata filtering capabilities that can be applied
|
||||||
|
to search results from any LEANN backend. The filtering supports various
|
||||||
|
operators for different data types including numbers, strings, booleans, and lists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type alias for filter specifications
|
||||||
|
FilterValue = Union[str, int, float, bool, list]
|
||||||
|
FilterSpec = dict[str, FilterValue]
|
||||||
|
MetadataFilters = dict[str, FilterSpec]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterEngine:
|
||||||
|
"""
|
||||||
|
Engine for evaluating metadata filters against search results.
|
||||||
|
|
||||||
|
Supports various operators for filtering based on metadata fields:
|
||||||
|
- Comparison: ==, !=, <, <=, >, >=
|
||||||
|
- Membership: in, not_in
|
||||||
|
- String operations: contains, starts_with, ends_with
|
||||||
|
- Boolean operations: is_true, is_false
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the filter engine with supported operators."""
|
||||||
|
self.operators = {
|
||||||
|
"==": self._equals,
|
||||||
|
"!=": self._not_equals,
|
||||||
|
"<": self._less_than,
|
||||||
|
"<=": self._less_than_or_equal,
|
||||||
|
">": self._greater_than,
|
||||||
|
">=": self._greater_than_or_equal,
|
||||||
|
"in": self._in,
|
||||||
|
"not_in": self._not_in,
|
||||||
|
"contains": self._contains,
|
||||||
|
"starts_with": self._starts_with,
|
||||||
|
"ends_with": self._ends_with,
|
||||||
|
"is_true": self._is_true,
|
||||||
|
"is_false": self._is_false,
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply_filters(
|
||||||
|
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to a list of search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of result dictionaries, each containing 'metadata' field
|
||||||
|
metadata_filters: Dictionary of filter specifications
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of search results
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying filters: {metadata_filters}")
|
||||||
|
logger.debug(f"Input results count: {len(search_results)}")
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for result in search_results:
|
||||||
|
if self._evaluate_filters(result, metadata_filters):
|
||||||
|
filtered_results.append(result)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results count: {len(filtered_results)}")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate all filters against a single search result.
|
||||||
|
|
||||||
|
All filters must pass (AND logic) for the result to be included.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary (including metadata, text, etc.)
|
||||||
|
filters: Filter specifications to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all filters pass, False otherwise
|
||||||
|
"""
|
||||||
|
for field_name, filter_spec in filters.items():
|
||||||
|
if not self._evaluate_field_filter(result, field_name, filter_spec):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _evaluate_field_filter(
|
||||||
|
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate a single field filter against a search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary
|
||||||
|
field_name: Name of the field to filter on
|
||||||
|
filter_spec: Filter specification for this field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the filter passes, False otherwise
|
||||||
|
"""
|
||||||
|
# First check top-level fields, then check metadata
|
||||||
|
field_value = result.get(field_name)
|
||||||
|
if field_value is None:
|
||||||
|
# Try to get from metadata if not found at top level
|
||||||
|
metadata = result.get("metadata", {})
|
||||||
|
field_value = metadata.get(field_name)
|
||||||
|
|
||||||
|
# Handle missing fields - they fail all filters except existence checks
|
||||||
|
if field_value is None:
|
||||||
|
logger.debug(f"Field '{field_name}' not found in result or metadata")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Evaluate each operator in the filter spec
|
||||||
|
for operator, expected_value in filter_spec.items():
|
||||||
|
if operator not in self.operators:
|
||||||
|
logger.warning(f"Unsupported operator: {operator}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.operators[operator](field_value, expected_value):
|
||||||
|
logger.debug(
|
||||||
|
f"Filter failed: {field_name} {operator} {expected_value} "
|
||||||
|
f"(actual: {field_value})"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Comparison operators
|
||||||
|
def _equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value equals expected value."""
|
||||||
|
return field_value == expected_value
|
||||||
|
|
||||||
|
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value does not equal expected value."""
|
||||||
|
return field_value != expected_value
|
||||||
|
|
||||||
|
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
|
||||||
|
|
||||||
|
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
|
||||||
|
|
||||||
|
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
|
||||||
|
|
||||||
|
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
|
||||||
|
|
||||||
|
# Membership operators
|
||||||
|
def _in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'in' operator requires a list, tuple, or set")
|
||||||
|
return field_value in expected_value
|
||||||
|
|
||||||
|
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is not in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'not_in' operator requires a list, tuple, or set")
|
||||||
|
return field_value not in expected_value
|
||||||
|
|
||||||
|
# String operators
|
||||||
|
def _contains(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value contains the expected substring."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return expected_str in field_str
|
||||||
|
|
||||||
|
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value starts with the expected prefix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.startswith(expected_str)
|
||||||
|
|
||||||
|
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value ends with the expected suffix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.endswith(expected_str)
|
||||||
|
|
||||||
|
# Boolean operators
|
||||||
|
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is truthy."""
|
||||||
|
return bool(field_value)
|
||||||
|
|
||||||
|
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is falsy."""
|
||||||
|
return not bool(field_value)
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
|
||||||
|
"""
|
||||||
|
Helper for numeric comparisons with type coercion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_value: Value from metadata
|
||||||
|
expected_value: Value to compare against
|
||||||
|
compare_func: Comparison function to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of comparison
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to convert both values to numbers for comparison
|
||||||
|
if isinstance(field_value, str) and isinstance(expected_value, str):
|
||||||
|
# String comparison if both are strings
|
||||||
|
return compare_func(field_value, expected_value)
|
||||||
|
|
||||||
|
# Numeric comparison - attempt to convert to float
|
||||||
|
field_num = (
|
||||||
|
float(field_value) if not isinstance(field_value, (int, float)) else field_value
|
||||||
|
)
|
||||||
|
expected_num = (
|
||||||
|
float(expected_value)
|
||||||
|
if not isinstance(expected_value, (int, float))
|
||||||
|
else expected_value
|
||||||
|
)
|
||||||
|
|
||||||
|
return compare_func(field_num, expected_num)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Fall back to string comparison if numeric conversion fails
|
||||||
|
return compare_func(str(field_value), str(expected_value))
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||||
|
|
||||||
|
For agent-facing discovery details, see `llms.txt` in the repository root.
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Install LEANN globally for MCP integration (with default backend):
|
Install LEANN globally for MCP integration (with default backend):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann"
|
name = "leann"
|
||||||
version = "0.3.2"
|
version = "0.3.4"
|
||||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ wechat-exporter = "wechat_exporter.main:main"
|
|||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|||||||
365
tests/test_metadata_filtering.py
Normal file
365
tests/test_metadata_filtering.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Comprehensive tests for metadata filtering functionality.
|
||||||
|
|
||||||
|
This module tests the MetadataFilterEngine class and its integration
|
||||||
|
with the LEANN search system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import the modules we're testing
|
||||||
|
import sys
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import PassageManager, SearchResult
|
||||||
|
from leann.metadata_filter import MetadataFilterEngine
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataFilterEngine:
|
||||||
|
"""Test suite for the MetadataFilterEngine class."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
self.engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample search results for testing
|
||||||
|
self.sample_results = [
|
||||||
|
{
|
||||||
|
"id": "doc1",
|
||||||
|
"score": 0.95,
|
||||||
|
"text": "This is chapter 1 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["adventure", "fantasy"],
|
||||||
|
"word_count": 150,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc2",
|
||||||
|
"score": 0.87,
|
||||||
|
"text": "This is chapter 3 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 3,
|
||||||
|
"character": "Bob",
|
||||||
|
"tags": ["mystery", "thriller"],
|
||||||
|
"word_count": 250,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc3",
|
||||||
|
"score": 0.82,
|
||||||
|
"text": "This is chapter 5 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 5,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["romance", "drama"],
|
||||||
|
"word_count": 300,
|
||||||
|
"is_published": False,
|
||||||
|
"genre": "non-fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc4",
|
||||||
|
"score": 0.78,
|
||||||
|
"text": "This is chapter 10 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 10,
|
||||||
|
"character": "Charlie",
|
||||||
|
"tags": ["action", "adventure"],
|
||||||
|
"word_count": 400,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_engine_initialization(self):
|
||||||
|
"""Test that the filter engine initializes correctly."""
|
||||||
|
assert self.engine is not None
|
||||||
|
assert len(self.engine.operators) > 0
|
||||||
|
assert "==" in self.engine.operators
|
||||||
|
assert "contains" in self.engine.operators
|
||||||
|
assert "in" in self.engine.operators
|
||||||
|
|
||||||
|
def test_direct_instantiation(self):
|
||||||
|
"""Test direct instantiation of the engine."""
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
assert isinstance(engine, MetadataFilterEngine)
|
||||||
|
|
||||||
|
def test_no_filters_returns_all_results(self):
|
||||||
|
"""Test that passing None or empty filters returns all results."""
|
||||||
|
# Test with None
|
||||||
|
result = self.engine.apply_filters(self.sample_results, None)
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test with empty dict
|
||||||
|
result = self.engine.apply_filters(self.sample_results, {})
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test comparison operators
|
||||||
|
def test_equals_filter(self):
|
||||||
|
"""Test equals (==) filter."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_not_equals_filter(self):
|
||||||
|
"""Test not equals (!=) filter."""
|
||||||
|
filters = {"genre": {"!=": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_less_than_filter(self):
|
||||||
|
"""Test less than (<) filter."""
|
||||||
|
filters = {"chapter": {"<": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch < 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_less_than_or_equal_filter(self):
|
||||||
|
"""Test less than or equal (<=) filter."""
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_greater_than_filter(self):
|
||||||
|
"""Test greater than (>) filter."""
|
||||||
|
filters = {"word_count": {">": 200}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3 # Documents with word_count 250, 300, 400
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc > 200 for wc in word_counts)
|
||||||
|
|
||||||
|
def test_greater_than_or_equal_filter(self):
|
||||||
|
"""Test greater than or equal (>=) filter."""
|
||||||
|
filters = {"word_count": {">=": 250}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc >= 250 for wc in word_counts)
|
||||||
|
|
||||||
|
# Test membership operators
|
||||||
|
def test_in_filter(self):
|
||||||
|
"""Test in filter."""
|
||||||
|
filters = {"character": {"in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
characters = [r["metadata"]["character"] for r in result]
|
||||||
|
assert all(ch in ["Alice", "Bob"] for ch in characters)
|
||||||
|
|
||||||
|
def test_not_in_filter(self):
|
||||||
|
"""Test not_in filter."""
|
||||||
|
filters = {"character": {"not_in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["character"] == "Charlie"
|
||||||
|
|
||||||
|
# Test string operators
|
||||||
|
def test_contains_filter(self):
|
||||||
|
"""Test contains filter."""
|
||||||
|
filters = {"genre": {"contains": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # Both "fiction" and "non-fiction"
|
||||||
|
|
||||||
|
def test_starts_with_filter(self):
|
||||||
|
"""Test starts_with filter."""
|
||||||
|
filters = {"genre": {"starts_with": "non"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_ends_with_filter(self):
|
||||||
|
"""Test ends_with filter."""
|
||||||
|
filters = {"text": {"ends_with": "content"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # All sample texts end with "content"
|
||||||
|
|
||||||
|
# Test boolean operators
|
||||||
|
def test_is_true_filter(self):
|
||||||
|
"""Test is_true filter."""
|
||||||
|
filters = {"is_published": {"is_true": True}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert all(r["metadata"]["is_published"] for r in result)
|
||||||
|
|
||||||
|
def test_is_false_filter(self):
|
||||||
|
"""Test is_false filter."""
|
||||||
|
filters = {"is_published": {"is_false": False}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert not result[0]["metadata"]["is_published"]
|
||||||
|
|
||||||
|
# Test compound filters (AND logic)
|
||||||
|
def test_compound_filters(self):
|
||||||
|
"""Test multiple filters applied together (AND logic)."""
|
||||||
|
filters = {"genre": {"==": "fiction"}, "chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert r["metadata"]["genre"] == "fiction"
|
||||||
|
assert r["metadata"]["chapter"] <= 5
|
||||||
|
|
||||||
|
def test_multiple_operators_same_field(self):
|
||||||
|
"""Test multiple operators on the same field."""
|
||||||
|
filters = {"word_count": {">=": 200, "<=": 350}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
wc = r["metadata"]["word_count"]
|
||||||
|
assert 200 <= wc <= 350
|
||||||
|
|
||||||
|
# Test edge cases
|
||||||
|
def test_missing_field_fails_filter(self):
|
||||||
|
"""Test that missing metadata fields fail filters."""
|
||||||
|
filters = {"nonexistent_field": {"==": "value"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_invalid_operator(self):
|
||||||
|
"""Test that invalid operators are handled gracefully."""
|
||||||
|
filters = {"chapter": {"invalid_op": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0 # Should filter out all results
|
||||||
|
|
||||||
|
def test_type_coercion_numeric(self):
|
||||||
|
"""Test numeric type coercion in comparisons."""
|
||||||
|
# Add a result with string chapter number
|
||||||
|
test_results = [
|
||||||
|
*self.sample_results,
|
||||||
|
{
|
||||||
|
"id": "doc5",
|
||||||
|
"score": 0.75,
|
||||||
|
"text": "String chapter test",
|
||||||
|
"metadata": {"chapter": "2", "genre": "test"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
filters = {"chapter": {"<": 3}}
|
||||||
|
result = self.engine.apply_filters(test_results, filters)
|
||||||
|
# Should include doc1 (chapter=1) and doc5 (chapter="2")
|
||||||
|
assert len(result) == 2
|
||||||
|
ids = [r["id"] for r in result]
|
||||||
|
assert "doc1" in ids
|
||||||
|
assert "doc5" in ids
|
||||||
|
|
||||||
|
def test_list_membership_with_nested_tags(self):
|
||||||
|
"""Test membership operations with list metadata."""
|
||||||
|
# Note: This tests the metadata structure, not list field filtering
|
||||||
|
# For list field filtering, we'd need to modify the test data
|
||||||
|
filters = {"character": {"in": ["Alice"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(r["metadata"]["character"] == "Alice" for r in result)
|
||||||
|
|
||||||
|
def test_empty_results_list(self):
|
||||||
|
"""Test filtering on empty results list."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters([], filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPassageManagerFiltering:
|
||||||
|
"""Test suite for PassageManager filtering integration."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
# Mock the passage manager without actual file I/O
|
||||||
|
self.passage_manager = Mock(spec=PassageManager)
|
||||||
|
self.passage_manager.filter_engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample SearchResult objects
|
||||||
|
self.search_results = [
|
||||||
|
SearchResult(
|
||||||
|
id="doc1",
|
||||||
|
score=0.95,
|
||||||
|
text="Chapter 1 content",
|
||||||
|
metadata={"chapter": 1, "character": "Alice"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc2",
|
||||||
|
score=0.87,
|
||||||
|
text="Chapter 5 content",
|
||||||
|
metadata={"chapter": 5, "character": "Bob"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc3",
|
||||||
|
score=0.82,
|
||||||
|
text="Chapter 10 content",
|
||||||
|
metadata={"chapter": 10, "character": "Alice"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_search_result_filtering(self):
|
||||||
|
"""Test filtering SearchResult objects."""
|
||||||
|
# Create a real PassageManager instance just for the filtering method
|
||||||
|
# We'll mock the file operations
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r.metadata["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_filter_search_results_no_filters(self):
|
||||||
|
"""Test that None filters return all results."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
result = pm.filter_search_results(self.search_results, None)
|
||||||
|
assert len(result) == len(self.search_results)
|
||||||
|
|
||||||
|
def test_filter_maintains_search_result_type(self):
|
||||||
|
"""Test that filtering returns SearchResult objects."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"character": {"==": "Alice"}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert isinstance(r, SearchResult)
|
||||||
|
assert r.metadata["character"] == "Alice"
|
||||||
|
|
||||||
|
|
||||||
|
# Integration tests would go here, but they require actual LEANN backend setup
|
||||||
|
# These would test the full pipeline from LeannSearcher.search() with metadata_filters
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run basic smoke tests
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
sample_data = [
|
||||||
|
{
|
||||||
|
"id": "test1",
|
||||||
|
"score": 0.9,
|
||||||
|
"text": "Test content",
|
||||||
|
"metadata": {"chapter": 1, "published": True},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test basic filtering
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 1}})
|
||||||
|
assert len(result) == 1
|
||||||
|
print("✅ Basic filtering test passed")
|
||||||
|
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 2}})
|
||||||
|
assert len(result) == 0
|
||||||
|
print("✅ No match filtering test passed")
|
||||||
|
|
||||||
|
print("🎉 All smoke tests passed!")
|
||||||
Reference in New Issue
Block a user