Compare commits
28 Commits
fix/update
...
fix/ask-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
47aeb85f82 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 |
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
106
.github/workflows/build-reusable.yml
vendored
106
.github/workflows/build-reusable.yml
vendored
@@ -54,6 +54,17 @@ jobs:
|
|||||||
python: '3.12'
|
python: '3.12'
|
||||||
- os: ubuntu-22.04
|
- os: ubuntu-22.04
|
||||||
python: '3.13'
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
python: '3.9'
|
python: '3.9'
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
@@ -87,7 +98,7 @@ jobs:
|
|||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
ref: ${{ inputs.ref }}
|
ref: ${{ inputs.ref }}
|
||||||
submodules: recursive
|
submodules: recursive
|
||||||
@@ -98,21 +109,56 @@ jobs:
|
|||||||
python-version: ${{ matrix.python }}
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@v4
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
- name: Install system dependencies (Ubuntu)
|
- name: Install system dependencies (Ubuntu)
|
||||||
if: runner.os == 'Linux'
|
if: runner.os == 'Linux'
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
patchelf
|
||||||
|
|
||||||
# Install Intel MKL for DiskANN
|
# Debug: Show system information
|
||||||
|
echo "🔍 System Information:"
|
||||||
|
echo "Architecture: $(uname -m)"
|
||||||
|
echo "OS: $(uname -a)"
|
||||||
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
|
|
||||||
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
source /opt/intel/oneapi/setvars.sh
|
source /opt/intel/oneapi/setvars.sh
|
||||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
- name: Install system dependencies (macOS)
|
- name: Install system dependencies (macOS)
|
||||||
if: runner.os == 'macOS'
|
if: runner.os == 'macOS'
|
||||||
@@ -304,3 +350,53 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
path: packages/*/dist/
|
path: packages/*/dist/
|
||||||
|
|
||||||
|
|
||||||
|
arch-smoke:
|
||||||
|
name: Arch Linux smoke test (install & import)
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: archlinux:latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare system
|
||||||
|
run: |
|
||||||
|
pacman -Syu --noconfirm
|
||||||
|
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||||
|
|
||||||
|
- name: Download ALL wheel artifacts from this run
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
# Don't specify name, download all artifacts
|
||||||
|
path: ./wheels
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Create virtual environment and install wheels
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
uv pip install --find-links wheels leann-core
|
||||||
|
uv pip install --find-links wheels leann-backend-hnsw
|
||||||
|
uv pip install --find-links wheels leann-backend-diskann
|
||||||
|
uv pip install --find-links wheels leann
|
||||||
|
|
||||||
|
- name: Import & tiny runtime check
|
||||||
|
env:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
python - <<'PY'
|
||||||
|
import leann
|
||||||
|
import leann_backend_hnsw as h
|
||||||
|
import leann_backend_diskann as d
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
b = LeannBuilder(backend_name="hnsw")
|
||||||
|
b.add_text("hello arch")
|
||||||
|
b.build_index("arch_demo.leann")
|
||||||
|
s = LeannSearcher("arch_demo.leann")
|
||||||
|
print("search:", s.search("hello", top_k=1))
|
||||||
|
PY
|
||||||
|
|||||||
2
.github/workflows/link-check.yml
vendored
2
.github/workflows/link-check.yml
vendored
@@ -14,6 +14,6 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
- uses: lycheeverse/lychee-action@v2
|
- uses: lycheeverse/lychee-action@v2
|
||||||
with:
|
with:
|
||||||
args: --no-progress --insecure README.md docs/ apps/ examples/ benchmarks/
|
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -22,6 +22,7 @@ demo/experiment_results/**/*.json
|
|||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -93,3 +94,10 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
batchtest.py
|
batchtest.py
|
||||||
tests/__pytest_cache__/
|
tests/__pytest_cache__/
|
||||||
tests/__pycache__/
|
tests/__pycache__/
|
||||||
|
paru-bin/
|
||||||
|
|
||||||
|
CLAUDE.md
|
||||||
|
CLAUDE.local.md
|
||||||
|
.claude/*.local.*
|
||||||
|
.claude/local/*
|
||||||
|
benchmarks/data/
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||||
|
|||||||
@@ -13,4 +13,5 @@ repos:
|
|||||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
|
|||||||
130
README.md
130
README.md
@@ -8,6 +8,8 @@
|
|||||||
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
<img src="https://img.shields.io/badge/Platform-Ubuntu%20%26%20Arch%20%26%20WSL%20%7C%20macOS%20(ARM64%2FIntel)-lightgrey" alt="Platform">
|
||||||
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
<img src="https://img.shields.io/badge/License-MIT-green.svg" alt="MIT License">
|
||||||
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
<img src="https://img.shields.io/badge/MCP-Native%20Integration-blue" alt="MCP Integration">
|
||||||
|
<a href="https://join.slack.com/t/leann-e2u9779/shared_invite/zt-3ckd2f6w1-OX08~NN4gkWhh10PRVBj1Q"><img src="https://img.shields.io/badge/Slack-Join-4A154B?logo=slack&logoColor=white" alt="Join Slack">
|
||||||
|
<a href="assets/wechat_user_group.JPG" title="Join WeChat group"><img src="https://img.shields.io/badge/WeChat-Join-2DC100?logo=wechat&logoColor=white" alt="Join WeChat group"></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
<h2 align="center" tabindex="-1" class="heading-element" dir="auto">
|
||||||
@@ -87,17 +89,60 @@ git submodule update --init --recursive
|
|||||||
```
|
```
|
||||||
|
|
||||||
**macOS:**
|
**macOS:**
|
||||||
|
|
||||||
|
Note: DiskANN requires MacOS 13.3 or later.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
brew install llvm libomp boost protobuf zeromq pkgconf
|
brew install libomp boost protobuf zeromq pkgconf
|
||||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
**Linux:**
|
**Linux (Ubuntu/Debian):**
|
||||||
```bash
|
|
||||||
# Ubuntu/Debian (For Arch Linux: sudo pacman -S blas lapack openblas libaio boost protobuf abseil-cpp zeromq)
|
|
||||||
sudo apt-get update && sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
|
||||||
|
|
||||||
uv sync
|
Note: On Ubuntu 20.04, you may need to build a newer Abseil and pin Protobuf (e.g., v3.20.x) for building DiskANN. See [Issue #30](https://github.com/yichuan-w/LEANN/issues/30) for a step-by-step note.
|
||||||
|
|
||||||
|
You can manually install [Intel oneAPI MKL](https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl.html) instead of `libmkl-full-dev` for DiskANN. You can also use `libopenblas-dev` for building HNSW only, by removing `--extra diskann` in the command below.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo apt-get update && sudo apt-get install -y \
|
||||||
|
libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
libmkl-full-dev
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (Arch Linux):**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo pacman -Syu && sudo pacman -S --needed base-devel cmake pkgconf git gcc \
|
||||||
|
boost boost-libs protobuf abseil-cpp libaio zeromq
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo pacman -S --needed base-devel git
|
||||||
|
git clone https://aur.archlinux.org/paru-bin.git
|
||||||
|
cd paru-bin && makepkg -si
|
||||||
|
paru -S intel-oneapi-mkl intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
|
```
|
||||||
|
|
||||||
|
**Linux (RHEL / CentOS Stream / Oracle / Rocky / AlmaLinux):**
|
||||||
|
|
||||||
|
See [Issue #50](https://github.com/yichuan-w/LEANN/issues/50) for more details.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo dnf groupinstall -y "Development Tools"
|
||||||
|
sudo dnf install -y libomp-devel boost-devel protobuf-compiler protobuf-devel \
|
||||||
|
abseil-cpp-devel libaio-devel zeromq-devel pkgconf-pkg-config
|
||||||
|
|
||||||
|
# For MKL in DiskANN
|
||||||
|
sudo dnf install -y intel-oneapi-mkl intel-oneapi-mkl-devel \
|
||||||
|
intel-oneapi-openmp || sudo dnf install -y intel-oneapi-compiler
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
|
uv sync --extra diskann
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -133,6 +178,8 @@ response = chat.ask("How much storage does LEANN save?", top_k=1)
|
|||||||
|
|
||||||
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
LEANN supports RAG on various data sources including documents (`.pdf`, `.txt`, `.md`), Apple Mail, Google Search History, WeChat, and more.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Generation Model Setup
|
### Generation Model Setup
|
||||||
|
|
||||||
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
LEANN supports multiple LLM providers for text generation (OpenAI API, HuggingFace, Ollama).
|
||||||
@@ -175,7 +222,8 @@ ollama pull llama3.2:1b
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### ⭐ Flexible Configuration
|
|
||||||
|
## ⭐ Flexible Configuration
|
||||||
|
|
||||||
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
LEANN provides flexible parameters for embedding models, search strategies, and data processing to fit your specific needs.
|
||||||
|
|
||||||
@@ -251,6 +299,12 @@ python -m apps.document_rag --data-dir "~/Documents/Papers" --chunk-size 1024
|
|||||||
|
|
||||||
# Filter only markdown and Python files with smaller chunks
|
# Filter only markdown and Python files with smaller chunks
|
||||||
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
python -m apps.document_rag --data-dir "./docs" --chunk-size 256 --file-types .md .py
|
||||||
|
|
||||||
|
# Enable AST-aware chunking for code files
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir "./my_project"
|
||||||
|
|
||||||
|
# Or use the specialized code RAG for better code understanding
|
||||||
|
python -m apps.code_rag --repo-dir "./my_codebase" --query "How does authentication work?"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@@ -425,10 +479,20 @@ Once the index is built, you can ask questions like:
|
|||||||
|
|
||||||
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
### 🚀 Claude Code Integration: Transform Your Development Workflow!
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>NEW!! AST‑Aware Code Chunking</strong></summary>
|
||||||
|
|
||||||
|
LEANN features intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript, improving code understanding compared to text-based chunking.
|
||||||
|
|
||||||
|
📖 Read the [AST Chunking Guide →](docs/ast_chunking_guide.md)
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
**The future of code assistance is here.** Transform your development workflow with LEANN's native MCP integration for Claude Code. Index your entire codebase and get intelligent code assistance directly in your IDE.
|
||||||
|
|
||||||
**Key features:**
|
**Key features:**
|
||||||
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
- 🔍 **Semantic code search** across your entire project, fully local index and lightweight
|
||||||
|
- 🧠 **AST-aware chunking** preserves code structure (functions, classes)
|
||||||
- 📚 **Context-aware assistance** for debugging and development
|
- 📚 **Context-aware assistance** for debugging and development
|
||||||
- 🚀 **Zero-config setup** with automatic language detection
|
- 🚀 **Zero-config setup** with automatic language detection
|
||||||
|
|
||||||
@@ -482,6 +546,9 @@ leann search my-docs "machine learning concepts"
|
|||||||
# Interactive chat with your documents
|
# Interactive chat with your documents
|
||||||
leann ask my-docs --interactive
|
leann ask my-docs --interactive
|
||||||
|
|
||||||
|
# Ask a single question (non-interactive)
|
||||||
|
leann ask my-docs "Where are prompts configured?"
|
||||||
|
|
||||||
# List all your indexes
|
# List all your indexes
|
||||||
leann list
|
leann list
|
||||||
|
|
||||||
@@ -491,7 +558,8 @@ leann remove my-docs
|
|||||||
|
|
||||||
**Key CLI features:**
|
**Key CLI features:**
|
||||||
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
- Auto-detects document formats (PDF, TXT, MD, DOCX, PPTX + code files)
|
||||||
- Smart text chunking with overlap
|
- **🧠 AST-aware chunking** for Python, Java, C#, TypeScript files
|
||||||
|
- Smart text chunking with overlap for all other content
|
||||||
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
- Multiple LLM providers (Ollama, OpenAI, HuggingFace)
|
||||||
- Organized index storage in `.leann/indexes/` (project-local)
|
- Organized index storage in `.leann/indexes/` (project-local)
|
||||||
- Support for advanced search parameters
|
- Support for advanced search parameters
|
||||||
@@ -564,6 +632,46 @@ Options:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## 🚀 Advanced Features
|
||||||
|
|
||||||
|
### 🎯 Metadata Filtering
|
||||||
|
|
||||||
|
LEANN supports a simple metadata filtering system to enable sophisticated use cases like document filtering by date/type, code search by file extension, and content management based on custom criteria.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Add metadata during indexing
|
||||||
|
builder.add_text(
|
||||||
|
"def authenticate_user(token): ...",
|
||||||
|
metadata={"file_extension": ".py", "lines_of_code": 25}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Supported operators**: `==`, `!=`, `<`, `<=`, `>`, `>=`, `in`, `not_in`, `contains`, `starts_with`, `ends_with`, `is_true`, `is_false`
|
||||||
|
|
||||||
|
📖 **[Complete Metadata filtering guide →](docs/metadata_filtering.md)**
|
||||||
|
|
||||||
|
### 🔍 Grep Search
|
||||||
|
|
||||||
|
For exact text matching instead of semantic search, use the `use_grep` parameter:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("banana‑crocodile", use_grep=True, top_k=1)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use cases**: Finding specific code patterns, error messages, function names, or exact phrases where semantic similarity isn't needed.
|
||||||
|
|
||||||
|
📖 **[Complete grep search guide →](docs/grep_search.md)**
|
||||||
|
|
||||||
## 🏗️ Architecture & How It Works
|
## 🏗️ Architecture & How It Works
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
@@ -603,6 +711,7 @@ Options:
|
|||||||
```bash
|
```bash
|
||||||
uv pip install -e ".[dev]" # Install dev dependencies
|
uv pip install -e ".[dev]" # Install dev dependencies
|
||||||
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
python benchmarks/run_evaluation.py # Will auto-download evaluation data and run benchmarks
|
||||||
|
python benchmarks/run_evaluation.py benchmarks/data/indices/rpj_wiki/rpj_wiki --num-queries 2000 # After downloading data, you can run the benchmark with our biggest index
|
||||||
```
|
```
|
||||||
|
|
||||||
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
The evaluation script downloads data automatically on first run. The last three results were tested with partial personal data, and you can reproduce them with your own data!
|
||||||
@@ -642,6 +751,9 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|||||||
|
|
||||||
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
Core Contributors: [Yichuan Wang](https://yichuan-w.github.io/) & [Zhifei Li](https://github.com/andylizf).
|
||||||
|
|
||||||
|
Active Contributors: [Gabriel Dehan](https://github.com/gabriel-dehan)
|
||||||
|
|
||||||
|
|
||||||
We welcome more contributors! Feel free to open issues or submit PRs.
|
We welcome more contributors! Feel free to open issues or submit PRs.
|
||||||
|
|
||||||
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
This work is done at [**Berkeley Sky Computing Lab**](https://sky.cs.berkeley.edu/).
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from typing import Any
|
|||||||
import dotenv
|
import dotenv
|
||||||
from leann.api import LeannBuilder, LeannChat
|
from leann.api import LeannBuilder, LeannChat
|
||||||
from leann.registry import register_project_directory
|
from leann.registry import register_project_directory
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
@@ -79,6 +79,24 @@ class BaseRAGExample(ABC):
|
|||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
)
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# LLM parameters
|
# LLM parameters
|
||||||
llm_group = parser.add_argument_group("LLM Parameters")
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
@@ -98,8 +116,8 @@ class BaseRAGExample(ABC):
|
|||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--llm-host",
|
"--llm-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="http://localhost:11434",
|
default=None,
|
||||||
help="Host for Ollama API (default: http://localhost:11434)",
|
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
)
|
)
|
||||||
llm_group.add_argument(
|
llm_group.add_argument(
|
||||||
"--thinking-budget",
|
"--thinking-budget",
|
||||||
@@ -108,6 +126,50 @@ class BaseRAGExample(ABC):
|
|||||||
default=None,
|
default=None,
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# AST Chunking parameters
|
||||||
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Overlap between AST chunks (default: 64)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--code-file-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
# Search parameters
|
# Search parameters
|
||||||
search_group = parser.add_argument_group("Search Parameters")
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
@@ -174,9 +236,13 @@ class BaseRAGExample(ABC):
|
|||||||
|
|
||||||
if args.llm == "openai":
|
if args.llm == "openai":
|
||||||
config["model"] = args.llm_model or "gpt-4o"
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||||
|
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||||
|
if resolved_key:
|
||||||
|
config["api_key"] = resolved_key
|
||||||
elif args.llm == "ollama":
|
elif args.llm == "ollama":
|
||||||
config["model"] = args.llm_model or "llama3.2:1b"
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
config["host"] = args.llm_host
|
config["host"] = resolve_ollama_host(args.llm_host)
|
||||||
elif args.llm == "hf":
|
elif args.llm == "hf":
|
||||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
elif args.llm == "simulated":
|
elif args.llm == "simulated":
|
||||||
@@ -192,10 +258,20 @@ class BaseRAGExample(ABC):
|
|||||||
print(f"\n[Building Index] Creating {self.name} index...")
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
print(f"Total text chunks: {len(texts)}")
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend_name,
|
backend_name=args.backend_name,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.build_complexity,
|
complexity=args.build_complexity,
|
||||||
is_compact=not args.no_compact,
|
is_compact=not args.no_compact,
|
||||||
@@ -268,7 +344,6 @@ class BaseRAGExample(ABC):
|
|||||||
chat = LeannChat(
|
chat = LeannChat(
|
||||||
index_path,
|
index_path,
|
||||||
llm_config=self.get_llm_config(args),
|
llm_config=self.get_llm_config(args),
|
||||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
|
||||||
complexity=args.search_complexity,
|
complexity=args.search_complexity,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -310,21 +385,3 @@ class BaseRAGExample(ABC):
|
|||||||
await self.run_single_query(args, index_path, args.query)
|
await self.run_single_query(args, index_path, args.query)
|
||||||
else:
|
else:
|
||||||
await self.run_interactive_chat(args, index_path)
|
await self.run_interactive_chat(args, index_path)
|
||||||
|
|
||||||
|
|
||||||
def create_text_chunks(documents, chunk_size=256, chunk_overlap=25) -> list[str]:
|
|
||||||
"""Helper function to create text chunks from documents."""
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap,
|
|
||||||
separator=" ",
|
|
||||||
paragraph_separator="\n\n",
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
if nodes:
|
|
||||||
all_texts.extend(node.get_content() for node in nodes)
|
|
||||||
|
|
||||||
return all_texts
|
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .history_data.history import ChromeHistoryReader
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|||||||
44
apps/chunking/__init__.py
Normal file
44
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Unified chunking utilities facade.
|
||||||
|
|
||||||
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"create_ast_chunks",
|
||||||
|
"create_text_chunks",
|
||||||
|
"create_traditional_chunks",
|
||||||
|
"detect_code_files",
|
||||||
|
"get_language_from_extension",
|
||||||
|
]
|
||||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||||
|
Specialized for code repositories with automatic language detection and
|
||||||
|
optimized chunking parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRAG(BaseRAGExample):
|
||||||
|
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Code",
|
||||||
|
description="Process and query code repositories with AST-aware chunking",
|
||||||
|
default_index_name="code_index",
|
||||||
|
)
|
||||||
|
# Override defaults for code-specific usage
|
||||||
|
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||||
|
self.max_items_default = -1 # Process all code files by default
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add code-specific arguments."""
|
||||||
|
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||||
|
|
||||||
|
code_group.add_argument(
|
||||||
|
"--repo-dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Code repository directory to index (default: current directory)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=list(CODE_EXTENSIONS.keys()),
|
||||||
|
help="File extensions to include (default: supported code extensions)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--exclude-dirs",
|
||||||
|
nargs="+",
|
||||||
|
default=[
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
".venv",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
],
|
||||||
|
help="Directories to exclude from indexing",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=int,
|
||||||
|
default=1000000, # 1MB
|
||||||
|
help="Maximum file size in bytes to process (default: 1MB)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-comments",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in chunking (useful for documentation)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--preserve-imports",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Try to preserve import statements in chunks (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load code files and convert to AST-aware chunks."""
|
||||||
|
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||||
|
print(f"📁 Including extensions: {args.include_extensions}")
|
||||||
|
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||||
|
|
||||||
|
# Check if repository directory exists
|
||||||
|
repo_path = Path(args.repo_dir)
|
||||||
|
if not repo_path.exists():
|
||||||
|
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||||
|
|
||||||
|
# Load code files with filtering
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"required_exts": args.include_extensions,
|
||||||
|
"exclude_hidden": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create exclusion filter
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
"""Filter out unwanted files and directories."""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
if path.stat().st_size > args.max_file_size:
|
||||||
|
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if in excluded directory
|
||||||
|
for exclude_dir in args.exclude_dirs:
|
||||||
|
if exclude_dir in path.parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load documents with file filtering
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.repo_dir,
|
||||||
|
file_extractor=None, # Use default extractors
|
||||||
|
**reader_kwargs,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
# Apply custom filtering
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_filter(file_path):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
|
documents = filtered_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error loading code files: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(
|
||||||
|
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"✅ Loaded {len(documents)} code files")
|
||||||
|
|
||||||
|
# Show breakdown by language/extension
|
||||||
|
ext_counts = {}
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||||
|
|
||||||
|
print("📊 Files by extension:")
|
||||||
|
for ext, count in sorted(ext_counts.items()):
|
||||||
|
print(f" {ext}: {count} files")
|
||||||
|
|
||||||
|
# Use AST-aware chunking by default for code
|
||||||
|
print(
|
||||||
|
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=256, # Fallback for non-code files
|
||||||
|
chunk_overlap=64,
|
||||||
|
use_ast_chunking=True, # Always use AST for code RAG
|
||||||
|
ast_chunk_size=args.ast_chunk_size,
|
||||||
|
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||||
|
code_file_extensions=args.include_extensions,
|
||||||
|
ast_fallback_traditional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for code RAG
|
||||||
|
print("\n💻 Code RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'How does the embedding computation work?'")
|
||||||
|
print("- 'What are the main classes in this codebase?'")
|
||||||
|
print("- 'Show me the search implementation'")
|
||||||
|
print("- 'How is error handling implemented?'")
|
||||||
|
print("- 'What design patterns are used?'")
|
||||||
|
print("- 'Explain the chunking logic'")
|
||||||
|
print("\n🚀 Features:")
|
||||||
|
print("- ✅ AST-aware chunking preserves code structure")
|
||||||
|
print("- ✅ Automatic language detection")
|
||||||
|
print("- ✅ Smart filtering of large files and common excludes")
|
||||||
|
print("- ✅ Optimized for code understanding")
|
||||||
|
print("\nUsage examples:")
|
||||||
|
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||||
|
print(
|
||||||
|
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = CodeRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
@@ -44,6 +45,11 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
doc_group.add_argument(
|
doc_group.add_argument(
|
||||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
)
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--enable-code-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files in the data directory",
|
||||||
|
)
|
||||||
|
|
||||||
async def load_data(self, args) -> list[str]:
|
async def load_data(self, args) -> list[str]:
|
||||||
"""Load documents and convert to text chunks."""
|
"""Load documents and convert to text chunks."""
|
||||||
@@ -76,9 +82,22 @@ class DocumentRAG(BaseRAGExample):
|
|||||||
|
|
||||||
print(f"Loaded {len(documents)} documents")
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
# Convert to text chunks
|
# Determine chunking strategy
|
||||||
|
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("Using AST-aware chunking for code files")
|
||||||
|
|
||||||
|
# Convert to text chunks with optional AST support
|
||||||
all_texts = create_text_chunks(
|
all_texts = create_text_chunks(
|
||||||
documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
documents,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
use_ast_chunking=use_ast,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||||
|
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply max_items limit if specified
|
# Apply max_items limit if specified
|
||||||
@@ -102,6 +121,10 @@ if __name__ == "__main__":
|
|||||||
print(
|
print(
|
||||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
)
|
)
|
||||||
|
print("\n🚀 NEW: Code-aware chunking available!")
|
||||||
|
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||||
|
print("- Supports Python, Java, C#, TypeScript files")
|
||||||
|
print("- Better semantic understanding of code structure")
|
||||||
print("\nOr run without --query for interactive mode\n")
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
rag = DocumentRAG()
|
rag = DocumentRAG()
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from pathlib import Path
|
|||||||
# Add parent directory to path for imports
|
# Add parent directory to path for imports
|
||||||
sys.path.insert(0, str(Path(__file__).parent))
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
from base_rag_example import BaseRAGExample, create_text_chunks
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
from .email_data.LEANN_email_reader import EmlxReader
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
|
|||||||
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
82
benchmarks/data/.gitattributes
vendored
82
benchmarks/data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - uncompressed
|
|
||||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - compressed
|
|
||||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - uncompressed
|
|
||||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.png filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - compressed
|
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Video files - compressed
|
|
||||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
44
benchmarks/data/README.md
Executable file
44
benchmarks/data/README.md
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
---
|
||||||
|
license: mit
|
||||||
|
---
|
||||||
|
|
||||||
|
# LEANN-RAG Evaluation Data
|
||||||
|
|
||||||
|
This repository contains the necessary data to run the recall evaluation scripts for the [LEANN-RAG](https://huggingface.co/LEANN-RAG) project.
|
||||||
|
|
||||||
|
## Dataset Components
|
||||||
|
|
||||||
|
This dataset is structured into three main parts:
|
||||||
|
|
||||||
|
1. **Pre-built LEANN Indices**:
|
||||||
|
* `dpr/`: A pre-built index for the DPR dataset.
|
||||||
|
* `rpj_wiki/`: A pre-built index for the RPJ-Wiki dataset.
|
||||||
|
These indices were created using the `leann-core` library and are required by the `LeannSearcher`.
|
||||||
|
|
||||||
|
2. **Ground Truth Data**:
|
||||||
|
* `ground_truth/`: Contains the ground truth files (`flat_results_nq_k3.json`) for both the DPR and RPJ-Wiki datasets. These files map queries to the original passage IDs from the Natural Questions benchmark, evaluated using the Contriever model.
|
||||||
|
|
||||||
|
3. **Queries**:
|
||||||
|
* `queries/`: Contains the `nq_open.jsonl` file with the Natural Questions queries used for the evaluation.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
To use this data, you can download it locally using the `huggingface-hub` library. First, install the library:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install huggingface-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you can download the entire dataset to a local directory (e.g., `data/`) with the following Python script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir="data"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
This will download all the necessary files into a local `data` folder, preserving the repository structure. The evaluation scripts in the main [LEANN-RAG Space](https://huggingface.co/LEANN-RAG) are configured to work with this data structure.
|
||||||
@@ -12,7 +12,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
@@ -197,6 +197,25 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
@@ -318,9 +337,24 @@ def main():
|
|||||||
|
|
||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(queries[i], top_k=args.top_k, ef=args.ef_search)
|
new_results = searcher.search(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ except ImportError:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: list[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
@@ -34,7 +34,7 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
@@ -179,10 +179,16 @@ class Benchmark:
|
|||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|||||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# AST-Aware Code chunking guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable AST chunking for mixed content (code + docs)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||||
|
|
||||||
|
# Specialized code repository indexing
|
||||||
|
python -m apps.code_rag --repo-dir ./my_codebase
|
||||||
|
|
||||||
|
# Global CLI with AST support
|
||||||
|
leann build my-code-index --docs ./src --use-ast-chunking
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN with AST chunking support
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### When to Use AST Chunking
|
||||||
|
|
||||||
|
✅ **Recommended for:**
|
||||||
|
- Code repositories with multiple languages
|
||||||
|
- Mixed documentation and code content
|
||||||
|
- Complex codebases with deep function/class hierarchies
|
||||||
|
- When working with Claude Code for code assistance
|
||||||
|
|
||||||
|
❌ **Not recommended for:**
|
||||||
|
- Pure text documents
|
||||||
|
- Very large files (>1MB)
|
||||||
|
- Languages not supported by tree-sitter
|
||||||
|
|
||||||
|
### Optimal Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended settings for most codebases
|
||||||
|
python -m apps.code_rag \
|
||||||
|
--repo-dir ./src \
|
||||||
|
--ast-chunk-size 768 \
|
||||||
|
--ast-chunk-overlap 96 \
|
||||||
|
--exclude-dirs .git __pycache__ node_modules build dist
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Languages
|
||||||
|
|
||||||
|
| Extension | Language | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| `.py` | Python | ✅ Full support |
|
||||||
|
| `.java` | Java | ✅ Full support |
|
||||||
|
| `.cs` | C# | ✅ Full support |
|
||||||
|
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||||
|
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||||
|
|
||||||
|
## Integration Examples
|
||||||
|
|
||||||
|
### Document RAG with Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Enable code chunking in document RAG
|
||||||
|
python -m apps.document_rag \
|
||||||
|
--enable-code-chunking \
|
||||||
|
--data-dir ./project \
|
||||||
|
--query "How does authentication work in the codebase?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||||
|
- Code completion and suggestions
|
||||||
|
- Bug analysis and debugging
|
||||||
|
- Architecture understanding
|
||||||
|
- Refactoring assistance
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Fallback to Traditional Chunking**
|
||||||
|
- Normal behavior for unsupported languages
|
||||||
|
- Check logs for specific language support
|
||||||
|
|
||||||
|
2. **Performance with Large Files**
|
||||||
|
- Adjust `--max-file-size` parameter
|
||||||
|
- Use `--exclude-dirs` to skip unnecessary directories
|
||||||
|
|
||||||
|
3. **Quality Issues**
|
||||||
|
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||||
|
- Adjust overlap for better context preservation
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LEANN_LOG_LEVEL=DEBUG
|
||||||
|
python -m apps.code_rag --repo-dir ./my_code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Traditional Chunking
|
||||||
|
|
||||||
|
Existing workflows continue to work without changes. To enable AST chunking:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Before
|
||||||
|
python -m apps.document_rag --chunk-size 256
|
||||||
|
|
||||||
|
# After (maintains traditional chunking for non-code files)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||||
|
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||||
|
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||||
@@ -83,6 +83,81 @@ ollama pull nomic-embed-text
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
## Local & Remote Inference Endpoints
|
||||||
|
|
||||||
|
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||||
|
|
||||||
|
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||||
|
|
||||||
|
### One-Time Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||||
|
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||||
|
|
||||||
|
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||||
|
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||||
|
|
||||||
|
### Passing Hosts Per Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index with a remote embedding server
|
||||||
|
leann build my-notes \
|
||||||
|
--docs ./notes \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||||
|
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||||
|
--embedding-api-key local-dev-key
|
||||||
|
|
||||||
|
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm openai \
|
||||||
|
--llm-model qwen3-8b \
|
||||||
|
--api-base http://localhost:1234/v1 \
|
||||||
|
--api-key local-dev-key
|
||||||
|
|
||||||
|
# Query an Ollama instance running on another box
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm ollama \
|
||||||
|
--llm-model qwen3:14b \
|
||||||
|
--host http://192.168.1.101:11434
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||||
|
|
||||||
|
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||||
|
- Configure router or cloud provider port forwarding.
|
||||||
|
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||||
|
|
||||||
|
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||||
|
|
||||||
|
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||||
|
|
||||||
|
### Python API Usage
|
||||||
|
|
||||||
|
You can pass the same configuration from Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://192.168.1.50:1234/v1",
|
||||||
|
"api_key": "local-dev-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-notes", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
## Index Selection: Matching Your Scale
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
### HNSW (Hierarchical Navigable Small World)
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
## 🔥 Core Features
|
## 🔥 Core Features
|
||||||
|
|
||||||
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||||
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||||
|
|||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
# LEANN Metadata Filtering Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Adding Metadata to Your Documents
|
||||||
|
|
||||||
|
When building your index, add metadata to each text chunk:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
|
||||||
|
# Add text with metadata
|
||||||
|
builder.add_text(
|
||||||
|
text="Chapter 1: Alice falls down the rabbit hole",
|
||||||
|
metadata={
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"themes": ["adventure", "curiosity"],
|
||||||
|
"word_count": 150
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.build_index("alice_in_wonderland_index")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Searching with Metadata Filters
|
||||||
|
|
||||||
|
Use the `metadata_filters` parameter in search calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="What happens to Alice?",
|
||||||
|
top_k=10,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||||
|
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Filter Syntax
|
||||||
|
|
||||||
|
### Basic Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"field_name": {"operator": value},
|
||||||
|
"another_field": {"operator": value}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
#### Comparison Operators
|
||||||
|
- `"=="`: Equal to
|
||||||
|
- `"!="`: Not equal to
|
||||||
|
- `"<"`: Less than
|
||||||
|
- `"<="`: Less than or equal
|
||||||
|
- `">"`: Greater than
|
||||||
|
- `">="`: Greater than or equal
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||||
|
{"page": {">": 100}} # Pages after 100
|
||||||
|
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||||
|
{"word_count": {"<": 500}} # Short passages
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Membership Operators
|
||||||
|
- `"in"`: Value is in list
|
||||||
|
- `"not_in"`: Value is not in list
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||||
|
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||||
|
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||||
|
```
|
||||||
|
|
||||||
|
#### String Operators
|
||||||
|
- `"contains"`: String contains substring
|
||||||
|
- `"starts_with"`: String starts with prefix
|
||||||
|
- `"ends_with"`: String ends with suffix
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||||
|
{"filename": {"ends_with": ".py"}} # Python files
|
||||||
|
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Boolean Operators
|
||||||
|
- `"is_true"`: Field is truthy
|
||||||
|
- `"is_false"`: Field is falsy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"is_published": {"is_true": True}} # Published content
|
||||||
|
{"is_draft": {"is_false": False}} # Not drafts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Operators on Same Field
|
||||||
|
|
||||||
|
You can apply multiple operators to the same field (AND logic):
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"word_count": {
|
||||||
|
">=": 100, # At least 100 words
|
||||||
|
"<=": 500 # At most 500 words
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compound Filters
|
||||||
|
|
||||||
|
Multiple fields are combined with AND logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"chapter": {"<=": 10}, # Up to chapter 10
|
||||||
|
"character": {"==": "Alice"}, # About Alice
|
||||||
|
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Case Examples
|
||||||
|
|
||||||
|
### 1. Spoiler-Free Book Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reader has only read up to chapter 5
|
||||||
|
def search_spoiler_free(query, max_chapter):
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": max_chapter},
|
||||||
|
"spoiler_level": {"in": ["none", "low"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Document Management by Date
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find recent documents
|
||||||
|
recent_docs = searcher.search(
|
||||||
|
query="project updates",
|
||||||
|
metadata_filters={
|
||||||
|
"date": {">=": "2024-01-01"},
|
||||||
|
"document_type": {"==": "report"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Code Search by File Type
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search only Python files
|
||||||
|
python_code = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Content Filtering by Audience
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Age-appropriate content
|
||||||
|
family_content = searcher.search(
|
||||||
|
query="adventure stories",
|
||||||
|
metadata_filters={
|
||||||
|
"age_rating": {"in": ["G", "PG"]},
|
||||||
|
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Multi-Book Series Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search across first 3 books only
|
||||||
|
early_series = searcher.search(
|
||||||
|
query="character development",
|
||||||
|
metadata_filters={
|
||||||
|
"series": {"==": "Harry Potter"},
|
||||||
|
"book_number": {"<=": 3}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Don't forget to set up the environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Run the spoiler-free book RAG example
|
||||||
|
uv run examples/spoiler_free_book_rag.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates:
|
||||||
|
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||||
|
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||||
|
- Different scenarios for readers at various points in the book
|
||||||
|
|
||||||
|
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||||
|
|
||||||
|
## Advanced Patterns
|
||||||
|
|
||||||
|
### Custom Chunking with metadata
|
||||||
|
|
||||||
|
```python
|
||||||
|
def chunk_book_with_metadata(book_text, book_info):
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||||
|
# Extract entities, themes, etc.
|
||||||
|
characters = extract_characters(chapter_text)
|
||||||
|
themes = classify_themes(chapter_text)
|
||||||
|
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||||
|
|
||||||
|
# Create chunks with rich metadata
|
||||||
|
for paragraph in split_paragraphs(chapter_text):
|
||||||
|
chunks.append({
|
||||||
|
"text": paragraph,
|
||||||
|
"metadata": {
|
||||||
|
"book_title": book_info["title"],
|
||||||
|
"chapter": chapter_num,
|
||||||
|
"characters": characters,
|
||||||
|
"themes": themes,
|
||||||
|
"spoiler_level": spoiler_level,
|
||||||
|
"word_count": len(paragraph.split()),
|
||||||
|
"reading_level": calculate_reading_level(paragraph)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Efficient Filtering Strategies
|
||||||
|
|
||||||
|
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||||
|
|
||||||
|
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||||
|
|
||||||
|
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||||
|
|
||||||
|
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||||
|
|
||||||
|
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||||
|
|
||||||
|
### Adding Metadata to Existing Indices
|
||||||
|
|
||||||
|
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Read existing passages and add metadata
|
||||||
|
def add_metadata_to_existing_chunks(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
# Extract or assign metadata based on content
|
||||||
|
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Rebuild index with metadata
|
||||||
|
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
for chunk in enhanced_chunks:
|
||||||
|
builder.add_text(chunk["text"], chunk["metadata"])
|
||||||
|
builder.build_index("enhanced_index")
|
||||||
|
```
|
||||||
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
404
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
"""Dynamic HNSW update demo without compact storage.
|
||||||
|
|
||||||
|
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||||
|
recompute:
|
||||||
|
|
||||||
|
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||||
|
2. Print the top results with `recompute_embeddings=True`.
|
||||||
|
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||||
|
4. Run the same query again to show the newly inserted passages.
|
||||||
|
|
||||||
|
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||||
|
ZMQ activity)::
|
||||||
|
|
||||||
|
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||||
|
uv run -m examples.dynamic_update_no_recompute \
|
||||||
|
--index-path .leann/examples/leann-demo.leann
|
||||||
|
|
||||||
|
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||||
|
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||||
|
It issues the query "What's LEANN?" before and after the update to show how the
|
||||||
|
new passages become immediately searchable. The script uses the
|
||||||
|
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||||
|
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||||
|
freshly added passages are embedded locally just like the initial build.
|
||||||
|
|
||||||
|
To make storage comparisons easy, the script can also build a matching
|
||||||
|
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||||
|
delta after the update. Disable the baseline run with
|
||||||
|
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
DEFAULT_QUERY = "What's LEANN?"
|
||||||
|
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||||
|
searcher = LeannSearcher(str(index_path))
|
||||||
|
try:
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(title: str, results: Iterable) -> None:
|
||||||
|
print(f"\n=== {title} ===")
|
||||||
|
res_list = list(results)
|
||||||
|
print(f"results count: {len(res_list)}")
|
||||||
|
print("passages:")
|
||||||
|
if not res_list:
|
||||||
|
print(" (no passages returned)")
|
||||||
|
for res in res_list:
|
||||||
|
snippet = res.text.replace("\n", " ")[:120]
|
||||||
|
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def update_index(
|
||||||
|
index_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
updater = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||||
|
updater.add_text(passage, metadata={"id": str(offset)})
|
||||||
|
updater.update_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
"""Remove leftover index artifacts for a clean rebuild."""
|
||||||
|
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def index_file_size(index_path: Path) -> int:
|
||||||
|
"""Return the size of the primary .index file for the given index path."""
|
||||||
|
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
return index_file.stat().st_size if index_file.exists() else 0
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(meta_path.read_text())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run_workflow(
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
index_path: Path,
|
||||||
|
initial_paragraphs: list[str],
|
||||||
|
update_paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
query: str,
|
||||||
|
top_k: int,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
prefix = f"[{label}] " if label else ""
|
||||||
|
|
||||||
|
ensure_index_dir(index_path)
|
||||||
|
cleanup_index_files(index_path)
|
||||||
|
|
||||||
|
print(f"{prefix}Building initial index...")
|
||||||
|
build_initial_index(
|
||||||
|
index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_size = index_file_size(index_path)
|
||||||
|
before_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{prefix}Updating index with additional passages...")
|
||||||
|
update_index(
|
||||||
|
index_path,
|
||||||
|
start_id=len(initial_paragraphs),
|
||||||
|
paragraphs=update_paragraphs,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
after_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
updated_size = index_file_size(index_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"initial_size": initial_size,
|
||||||
|
"updated_size": updated_size,
|
||||||
|
"delta": updated_size - initial_size,
|
||||||
|
"before_results": before_results,
|
||||||
|
"after_results": after_results,
|
||||||
|
"metadata": load_metadata_snapshot(index_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="+",
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Initial document files (PDF/TXT) used to build the base index",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/examples/leann-demo.leann"),
|
||||||
|
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-count",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of chunks to use from the initial documents (default: 8)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="*",
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Additional documents to add during update (PDF/TXT)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-count",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of chunks to append from update documents (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-text",
|
||||||
|
type=str,
|
||||||
|
default=(
|
||||||
|
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||||
|
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||||
|
"on demand to keep disk usage minimal."
|
||||||
|
),
|
||||||
|
help="Fallback text to append if --update-files is omitted",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of results to show for each search (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_QUERY,
|
||||||
|
help="Query to run before/after the update",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_false",
|
||||||
|
help="Skip building the no-recompute baseline.",
|
||||||
|
)
|
||||||
|
parser.set_defaults(compare_no_recompute=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||||
|
if not initial_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the initial files.")
|
||||||
|
|
||||||
|
initial = initial_chunks[: args.initial_count]
|
||||||
|
if not initial:
|
||||||
|
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||||
|
|
||||||
|
if args.update_files:
|
||||||
|
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||||
|
if not update_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the update files.")
|
||||||
|
to_add = update_chunks[: args.update_count]
|
||||||
|
else:
|
||||||
|
if not args.update_text:
|
||||||
|
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||||
|
to_add = [args.update_text]
|
||||||
|
if not to_add:
|
||||||
|
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||||
|
|
||||||
|
recompute_stats = run_workflow(
|
||||||
|
label="recompute",
|
||||||
|
index_path=args.index_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=True,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
print_results("initial search", recompute_stats["before_results"])
|
||||||
|
print_results("after update", recompute_stats["after_results"])
|
||||||
|
print(
|
||||||
|
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {recompute_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recompute_stats["metadata"]:
|
||||||
|
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
if args.compare_no_recompute:
|
||||||
|
baseline_path = (
|
||||||
|
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||||
|
)
|
||||||
|
baseline_stats = run_workflow(
|
||||||
|
label="no-recompute",
|
||||||
|
index_path=baseline_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=False,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {baseline_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
after_texts = [res.text for res in recompute_stats["after_results"]]
|
||||||
|
baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
|
||||||
|
if after_texts == baseline_after_texts:
|
||||||
|
print(
|
||||||
|
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||||
|
|
||||||
|
if baseline_stats["metadata"]:
|
||||||
|
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[no-recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||||
|
|
||||||
|
This example demonstrates how to use LEANN's metadata filtering to create
|
||||||
|
a spoiler-free book RAG system where users can search for information
|
||||||
|
up to a specific chapter they've read.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python spoiler_free_book_rag.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Add LEANN to path (adjust path as needed)
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create sample book chunks with metadata for demonstration.
|
||||||
|
|
||||||
|
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||||
|
and extract chapter boundaries, character mentions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_title: Title of the book
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with text and metadata
|
||||||
|
"""
|
||||||
|
# Sample book chunks with metadata
|
||||||
|
# In practice, you'd use proper text processing libraries
|
||||||
|
|
||||||
|
sample_chunks = [
|
||||||
|
{
|
||||||
|
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 1,
|
||||||
|
"characters": ["Alice", "Sister"],
|
||||||
|
"themes": ["boredom", "curiosity"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 2,
|
||||||
|
"characters": ["Alice", "White Rabbit"],
|
||||||
|
"themes": ["decision", "surprise", "magic"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 2,
|
||||||
|
"page": 15,
|
||||||
|
"characters": ["Alice"],
|
||||||
|
"themes": ["falling", "wonder", "transformation"],
|
||||||
|
"location": "rabbit hole",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 6,
|
||||||
|
"page": 85,
|
||||||
|
"characters": ["Alice", "Cheshire Cat"],
|
||||||
|
"themes": ["madness", "philosophy", "identity"],
|
||||||
|
"location": "Duchess's house",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 8,
|
||||||
|
"page": 120,
|
||||||
|
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||||
|
"themes": ["justice", "absurdity", "authority"],
|
||||||
|
"location": "Queen's court",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 12,
|
||||||
|
"page": 180,
|
||||||
|
"characters": ["Alice", "Sister", "Rabbit"],
|
||||||
|
"themes": ["revelation", "reality", "growth"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return sample_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a LEANN index with book chunks that include spoiler metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_chunks: List of book chunks with metadata
|
||||||
|
index_name: Name for the index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the built index
|
||||||
|
"""
|
||||||
|
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||||
|
|
||||||
|
# Initialize LEANN builder
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add each chunk with its metadata
|
||||||
|
for chunk in book_chunks:
|
||||||
|
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
index_path = f"{index_name}_book_index"
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
print(f"✅ Index built successfully: {index_path}")
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def spoiler_free_search(
|
||||||
|
index_path: str,
|
||||||
|
query: str,
|
||||||
|
max_chapter: int,
|
||||||
|
character_filter: Optional[list[str]] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Perform a spoiler-free search on the book index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: Search query
|
||||||
|
max_chapter: Maximum chapter number to include
|
||||||
|
character_filter: Optional list of characters to focus on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results safe for the reader
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||||
|
|
||||||
|
if character_filter:
|
||||||
|
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||||
|
|
||||||
|
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def demo_spoiler_free_rag():
|
||||||
|
"""
|
||||||
|
Demonstrate the spoiler-free book RAG system.
|
||||||
|
"""
|
||||||
|
print("🎭 Spoiler-Free Book RAG Demo")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
# Step 1: Prepare book data
|
||||||
|
book_title = "Alice's Adventures in Wonderland"
|
||||||
|
book_chunks = chunk_book_with_metadata(book_title)
|
||||||
|
|
||||||
|
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||||
|
|
||||||
|
# Step 2: Build the index (in practice, this would be done once)
|
||||||
|
try:
|
||||||
|
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||||
|
print(
|
||||||
|
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Demonstrate various spoiler-free searches
|
||||||
|
search_scenarios = [
|
||||||
|
{
|
||||||
|
"description": "Reader who has only read Chapter 1",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read up to Chapter 5",
|
||||||
|
"query": "Tell me about Alice's adventures",
|
||||||
|
"max_chapter": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read most of the book",
|
||||||
|
"query": "What does the Cheshire Cat represent?",
|
||||||
|
"max_chapter": 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read the whole book",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 12,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in search_scenarios:
|
||||||
|
print(f"\n📚 Scenario: {scenario['description']}")
|
||||||
|
print(f" Query: {scenario['query']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = spoiler_free_search(
|
||||||
|
index_path=index_path,
|
||||||
|
query=scenario["query"],
|
||||||
|
max_chapter=scenario["max_chapter"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📄 Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||||
|
chapter = result.metadata.get("chapter", "?")
|
||||||
|
location = result.metadata.get("location", "?")
|
||||||
|
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Search failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||||
|
print("=====================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_spoiler_free_rag()
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error running demo: {e}")
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import zmq
|
import zmq
|
||||||
@@ -32,6 +32,16 @@ if not logger.handlers:
|
|||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_diskann_embedding_server(
|
def create_diskann_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: Optional[str] = None,
|
||||||
zmq_port: int = 5555,
|
zmq_port: int = 5555,
|
||||||
@@ -83,9 +93,7 @@ def create_diskann_embedding_server(
|
|||||||
|
|
||||||
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
logger.info(f"Loading PassageManager with metadata_file_path: {passages_file}")
|
||||||
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
passages = PassageManager(meta["passage_sources"], metadata_file_path=passages_file)
|
||||||
logger.info(
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import protobuf after ensuring the path is correct
|
# Import protobuf after ensuring the path is correct
|
||||||
try:
|
try:
|
||||||
@@ -183,7 +191,12 @@ def create_diskann_embedding_server(
|
|||||||
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
logger.debug(f"Text lengths: {[len(t) for t in texts[:5]]}") # Show first 5
|
||||||
|
|
||||||
# Process embeddings using unified computation
|
# Process embeddings using unified computation
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
@@ -298,7 +311,12 @@ def create_diskann_embedding_server(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Process the request
|
# Process the request
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
logger.info(f"Computed embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy"]
|
requires = ["scikit-build-core>=0.10", "pybind11>=2.12.0", "numpy", "cmake>=3.30"]
|
||||||
build-backend = "scikit_build_core.build"
|
build-backend = "scikit_build_core.build"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-diskann"
|
name = "leann-backend-diskann"
|
||||||
version = "0.3.0"
|
version = "0.3.4"
|
||||||
dependencies = ["leann-core==0.3.0", "numpy", "protobuf>=3.19.0"]
|
dependencies = ["leann-core==0.3.4", "numpy", "protobuf>=3.19.0"]
|
||||||
|
|
||||||
[tool.scikit-build]
|
[tool.scikit-build]
|
||||||
# Key: simplified CMake path
|
# Key: simplified CMake path
|
||||||
|
|||||||
Submodule packages/leann-backend-diskann/third_party/DiskANN updated: 04048bb302...19f9603c72
@@ -49,9 +49,28 @@ set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
|
|||||||
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_C_API OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
|
||||||
# Disable additional SIMD versions to speed up compilation
|
# Disable x86-specific SIMD optimizations (important for ARM64 compatibility)
|
||||||
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX2 OFF CACHE BOOL "" FORCE)
|
||||||
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
set(FAISS_ENABLE_AVX512 OFF CACHE BOOL "" FORCE)
|
||||||
|
set(FAISS_ENABLE_SSE4_1 OFF CACHE BOOL "" FORCE)
|
||||||
|
|
||||||
|
# ARM64-specific configuration
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
|
||||||
|
message(STATUS "Configuring Faiss for ARM64 architecture")
|
||||||
|
|
||||||
|
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
|
||||||
|
# Use SVE optimization level for ARM64 Linux (as seen in Faiss conda build)
|
||||||
|
set(FAISS_OPT_LEVEL "sve" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'sve' for ARM64 Linux")
|
||||||
|
else()
|
||||||
|
# Use generic optimization for other ARM64 platforms (like macOS)
|
||||||
|
set(FAISS_OPT_LEVEL "generic" CACHE STRING "" FORCE)
|
||||||
|
message(STATUS "Setting FAISS_OPT_LEVEL to 'generic' for ARM64 ${CMAKE_SYSTEM_NAME}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# ARM64 compatibility: Faiss submodule has been modified to fix x86 header inclusion
|
||||||
|
message(STATUS "Using ARM64-compatible Faiss submodule")
|
||||||
|
endif()
|
||||||
|
|
||||||
# Additional optimization options from INSTALL.md
|
# Additional optimization options from INSTALL.md
|
||||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import os
|
|||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -237,6 +239,288 @@ def write_compact_format(
|
|||||||
f_out.write(storage_data)
|
f_out.write(storage_data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HNSWComponents:
|
||||||
|
original_hnsw_data: dict[str, Any]
|
||||||
|
assign_probas_np: np.ndarray
|
||||||
|
cum_nneighbor_per_level_np: np.ndarray
|
||||||
|
levels_np: np.ndarray
|
||||||
|
is_compact: bool
|
||||||
|
compact_level_ptr: Optional[np.ndarray] = None
|
||||||
|
compact_node_offsets_np: Optional[np.ndarray] = None
|
||||||
|
compact_neighbors_data: Optional[list[int]] = None
|
||||||
|
offsets_np: Optional[np.ndarray] = None
|
||||||
|
neighbors_np: Optional[np.ndarray] = None
|
||||||
|
storage_fourcc: int = NULL_INDEX_FOURCC
|
||||||
|
storage_data: bytes = b""
|
||||||
|
|
||||||
|
|
||||||
|
def _read_hnsw_structure(f) -> HNSWComponents:
|
||||||
|
original_hnsw_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
hnsw_index_fourcc = read_struct(f, "<I")
|
||||||
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unexpected HNSW FourCC: {hnsw_index_fourcc:08x}. Expected one of {EXPECTED_HNSW_FOURCCS}."
|
||||||
|
)
|
||||||
|
|
||||||
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
|
original_hnsw_data["d"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["ntotal"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["dummy1"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["dummy2"] = read_struct(f, "<q")
|
||||||
|
original_hnsw_data["is_trained"] = read_struct(f, "?")
|
||||||
|
original_hnsw_data["metric_type"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f, "<f")
|
||||||
|
|
||||||
|
assign_probas_np = read_numpy_vector(f, np.float64, "d")
|
||||||
|
cum_nneighbor_per_level_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
levels_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
|
||||||
|
ntotal = len(levels_np)
|
||||||
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
|
pos_before_compact = f.tell()
|
||||||
|
is_compact_flag = None
|
||||||
|
try:
|
||||||
|
is_compact_flag = read_struct(f, "<?")
|
||||||
|
except EOFError:
|
||||||
|
is_compact_flag = None
|
||||||
|
|
||||||
|
if is_compact_flag:
|
||||||
|
compact_level_ptr = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
compact_node_offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||||
|
|
||||||
|
storage_fourcc = read_struct(f, "<I")
|
||||||
|
compact_neighbors_data_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
|
storage_data = f.read()
|
||||||
|
|
||||||
|
return HNSWComponents(
|
||||||
|
original_hnsw_data=original_hnsw_data,
|
||||||
|
assign_probas_np=assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||||
|
levels_np=levels_np,
|
||||||
|
is_compact=True,
|
||||||
|
compact_level_ptr=compact_level_ptr,
|
||||||
|
compact_node_offsets_np=compact_node_offsets_np,
|
||||||
|
compact_neighbors_data=compact_neighbors_data,
|
||||||
|
storage_fourcc=storage_fourcc,
|
||||||
|
storage_data=storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-compact case
|
||||||
|
f.seek(pos_before_compact)
|
||||||
|
|
||||||
|
pos_before_probe = f.tell()
|
||||||
|
try:
|
||||||
|
suspected_flag = read_struct(f, "<B")
|
||||||
|
if suspected_flag != 0x00:
|
||||||
|
f.seek(pos_before_probe)
|
||||||
|
except EOFError:
|
||||||
|
f.seek(pos_before_probe)
|
||||||
|
|
||||||
|
offsets_np = read_numpy_vector(f, np.uint64, "Q")
|
||||||
|
neighbors_np = read_numpy_vector(f, np.int32, "i")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f, "<i")
|
||||||
|
|
||||||
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
storage_data = b""
|
||||||
|
try:
|
||||||
|
storage_fourcc = read_struct(f, "<I")
|
||||||
|
storage_data = f.read()
|
||||||
|
except EOFError:
|
||||||
|
storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
|
||||||
|
return HNSWComponents(
|
||||||
|
original_hnsw_data=original_hnsw_data,
|
||||||
|
assign_probas_np=assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np=cum_nneighbor_per_level_np,
|
||||||
|
levels_np=levels_np,
|
||||||
|
is_compact=False,
|
||||||
|
offsets_np=offsets_np,
|
||||||
|
neighbors_np=neighbors_np,
|
||||||
|
storage_fourcc=storage_fourcc,
|
||||||
|
storage_data=storage_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_hnsw_structure_from_file(path: str) -> HNSWComponents:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
return _read_hnsw_structure(f)
|
||||||
|
|
||||||
|
|
||||||
|
def write_original_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
offsets_np,
|
||||||
|
neighbors_np,
|
||||||
|
storage_fourcc,
|
||||||
|
storage_data,
|
||||||
|
):
|
||||||
|
"""Write non-compact HNSW data in original FAISS order."""
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<I", original_hnsw_data["index_fourcc"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["d"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["ntotal"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy1"]))
|
||||||
|
f_out.write(struct.pack("<q", original_hnsw_data["dummy2"]))
|
||||||
|
f_out.write(struct.pack("<?", original_hnsw_data["is_trained"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["metric_type"]))
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
f_out.write(struct.pack("<f", original_hnsw_data["metric_arg"]))
|
||||||
|
|
||||||
|
write_numpy_vector(f_out, assign_probas_np, "d")
|
||||||
|
write_numpy_vector(f_out, cum_nneighbor_per_level_np, "i")
|
||||||
|
write_numpy_vector(f_out, levels_np, "i")
|
||||||
|
|
||||||
|
write_numpy_vector(f_out, offsets_np, "Q")
|
||||||
|
write_numpy_vector(f_out, neighbors_np, "i")
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["entry_point"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["max_level"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["efConstruction"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["efSearch"]))
|
||||||
|
f_out.write(struct.pack("<i", original_hnsw_data["dummy_upper_beam"]))
|
||||||
|
|
||||||
|
f_out.write(struct.pack("<I", storage_fourcc))
|
||||||
|
if storage_fourcc != NULL_INDEX_FOURCC and storage_data:
|
||||||
|
f_out.write(storage_data)
|
||||||
|
|
||||||
|
|
||||||
|
def prune_hnsw_embeddings(input_filename: str, output_filename: str) -> bool:
|
||||||
|
"""Rewrite an HNSW index while dropping the embedded storage section."""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
with open(input_filename, "rb") as f_in, open(output_filename, "wb") as f_out:
|
||||||
|
original_hnsw_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
hnsw_index_fourcc = read_struct(f_in, "<I")
|
||||||
|
if hnsw_index_fourcc not in EXPECTED_HNSW_FOURCCS:
|
||||||
|
print(
|
||||||
|
f"Error: Expected HNSW Index FourCC ({list(EXPECTED_HNSW_FOURCCS)}), got {hnsw_index_fourcc:08x}.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
original_hnsw_data["index_fourcc"] = hnsw_index_fourcc
|
||||||
|
original_hnsw_data["d"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["ntotal"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["dummy1"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["dummy2"] = read_struct(f_in, "<q")
|
||||||
|
original_hnsw_data["is_trained"] = read_struct(f_in, "?")
|
||||||
|
original_hnsw_data["metric_type"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["metric_arg"] = 0.0
|
||||||
|
if original_hnsw_data["metric_type"] > 1:
|
||||||
|
original_hnsw_data["metric_arg"] = read_struct(f_in, "<f")
|
||||||
|
|
||||||
|
assign_probas_np = read_numpy_vector(f_in, np.float64, "d")
|
||||||
|
cum_nneighbor_per_level_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
levels_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
|
||||||
|
ntotal = len(levels_np)
|
||||||
|
if ntotal != original_hnsw_data["ntotal"]:
|
||||||
|
original_hnsw_data["ntotal"] = ntotal
|
||||||
|
|
||||||
|
pos_before_compact = f_in.tell()
|
||||||
|
is_compact_flag = None
|
||||||
|
try:
|
||||||
|
is_compact_flag = read_struct(f_in, "<?")
|
||||||
|
except EOFError:
|
||||||
|
is_compact_flag = None
|
||||||
|
|
||||||
|
if is_compact_flag:
|
||||||
|
compact_level_ptr = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
compact_node_offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
|
|
||||||
|
_storage_fourcc = read_struct(f_in, "<I")
|
||||||
|
compact_neighbors_data_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
compact_neighbors_data = compact_neighbors_data_np.tolist()
|
||||||
|
_storage_data = f_in.read()
|
||||||
|
|
||||||
|
write_compact_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
compact_level_ptr,
|
||||||
|
compact_node_offsets_np,
|
||||||
|
compact_neighbors_data,
|
||||||
|
NULL_INDEX_FOURCC,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
f_in.seek(pos_before_compact)
|
||||||
|
|
||||||
|
pos_before_probe = f_in.tell()
|
||||||
|
try:
|
||||||
|
suspected_flag = read_struct(f_in, "<B")
|
||||||
|
if suspected_flag != 0x00:
|
||||||
|
f_in.seek(pos_before_probe)
|
||||||
|
except EOFError:
|
||||||
|
f_in.seek(pos_before_probe)
|
||||||
|
|
||||||
|
offsets_np = read_numpy_vector(f_in, np.uint64, "Q")
|
||||||
|
neighbors_np = read_numpy_vector(f_in, np.int32, "i")
|
||||||
|
|
||||||
|
original_hnsw_data["entry_point"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["max_level"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efConstruction"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["efSearch"] = read_struct(f_in, "<i")
|
||||||
|
original_hnsw_data["dummy_upper_beam"] = read_struct(f_in, "<i")
|
||||||
|
|
||||||
|
_storage_fourcc = None
|
||||||
|
_storage_data = b""
|
||||||
|
try:
|
||||||
|
_storage_fourcc = read_struct(f_in, "<I")
|
||||||
|
_storage_data = f_in.read()
|
||||||
|
except EOFError:
|
||||||
|
_storage_fourcc = NULL_INDEX_FOURCC
|
||||||
|
|
||||||
|
write_original_format(
|
||||||
|
f_out,
|
||||||
|
original_hnsw_data,
|
||||||
|
assign_probas_np,
|
||||||
|
cum_nneighbor_per_level_np,
|
||||||
|
levels_np,
|
||||||
|
offsets_np,
|
||||||
|
neighbors_np,
|
||||||
|
NULL_INDEX_FOURCC,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[{time.time() - start_time:.2f}s] Pruned embeddings from {input_filename}")
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
print(f"Failed to prune embeddings: {exc}", file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# --- Main Conversion Logic ---
|
# --- Main Conversion Logic ---
|
||||||
|
|
||||||
|
|
||||||
@@ -700,6 +984,29 @@ def convert_hnsw_graph_to_csr(input_filename, output_filename, prune_embeddings=
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def prune_hnsw_embeddings_inplace(index_filename: str) -> bool:
|
||||||
|
"""Convenience wrapper to prune embeddings in-place."""
|
||||||
|
|
||||||
|
temp_path = f"{index_filename}.prune.tmp"
|
||||||
|
success = prune_hnsw_embeddings(index_filename, temp_path)
|
||||||
|
if success:
|
||||||
|
try:
|
||||||
|
os.replace(temp_path, index_filename)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive
|
||||||
|
logger.error(f"Failed to replace original index with pruned version: {exc}")
|
||||||
|
try:
|
||||||
|
os.remove(temp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
os.remove(temp_path)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return success
|
||||||
|
|
||||||
|
|
||||||
# --- Script Execution ---
|
# --- Script Execution ---
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
@@ -13,7 +14,7 @@ from leann.interface import (
|
|||||||
from leann.registry import register_backend
|
from leann.registry import register_backend
|
||||||
from leann.searcher_base import BaseSearcher
|
from leann.searcher_base import BaseSearcher
|
||||||
|
|
||||||
from .convert_to_csr import convert_hnsw_graph_to_csr
|
from .convert_to_csr import convert_hnsw_graph_to_csr, prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -91,6 +92,8 @@ class HNSWBuilder(LeannBackendBuilderInterface):
|
|||||||
|
|
||||||
if self.is_compact:
|
if self.is_compact:
|
||||||
self._convert_to_csr(index_file)
|
self._convert_to_csr(index_file)
|
||||||
|
elif self.is_recompute:
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
def _convert_to_csr(self, index_file: Path):
|
def _convert_to_csr(self, index_file: Path):
|
||||||
"""Convert built index to CSR format"""
|
"""Convert built index to CSR format"""
|
||||||
@@ -132,10 +135,10 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
if metric_enum is None:
|
if metric_enum is None:
|
||||||
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
raise ValueError(f"Unsupported distance_metric '{self.distance_metric}'.")
|
||||||
|
|
||||||
self.is_compact, self.is_pruned = (
|
backend_meta_kwargs = self.meta.get("backend_kwargs", {})
|
||||||
self.meta.get("is_compact", True),
|
self.is_compact = self.meta.get("is_compact", backend_meta_kwargs.get("is_compact", True))
|
||||||
self.meta.get("is_pruned", True),
|
default_pruned = backend_meta_kwargs.get("is_recompute", self.is_compact)
|
||||||
)
|
self.is_pruned = bool(self.meta.get("is_pruned", default_pruned))
|
||||||
|
|
||||||
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
index_file = self.index_dir / f"{self.index_path.stem}.index"
|
||||||
if not index_file.exists():
|
if not index_file.exists():
|
||||||
@@ -236,6 +239,7 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
distances = np.empty((batch_size_query, top_k), dtype=np.float32)
|
||||||
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
labels = np.empty((batch_size_query, top_k), dtype=np.int64)
|
||||||
|
|
||||||
|
search_time = time.time()
|
||||||
self._index.search(
|
self._index.search(
|
||||||
query.shape[0],
|
query.shape[0],
|
||||||
faiss.swig_ptr(query),
|
faiss.swig_ptr(query),
|
||||||
@@ -244,7 +248,8 @@ class HNSWSearcher(BaseSearcher):
|
|||||||
faiss.swig_ptr(labels),
|
faiss.swig_ptr(labels),
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
|
search_time = time.time() - search_time
|
||||||
|
logger.info(f" Search time in HNSWSearcher.search() backend: {search_time} seconds")
|
||||||
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
string_labels = [[str(int_label) for int_label in batch_labels] for batch_labels in labels]
|
||||||
|
|
||||||
return {"labels": string_labels, "distances": distances}
|
return {"labels": string_labels, "distances": distances}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import msgpack
|
import msgpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -24,14 +24,36 @@ logger = logging.getLogger(__name__)
|
|||||||
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
|
||||||
logger.setLevel(log_level)
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
# Ensure we have a handler if none exists
|
# Ensure we have handlers if none exist
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
handler = logging.StreamHandler()
|
stream_handler = logging.StreamHandler()
|
||||||
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
handler.setFormatter(formatter)
|
stream_handler.setFormatter(formatter)
|
||||||
logger.addHandler(handler)
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
|
log_path = os.getenv("LEANN_HNSW_LOG_PATH")
|
||||||
|
if log_path:
|
||||||
|
try:
|
||||||
|
file_handler = logging.FileHandler(log_path, mode="a", encoding="utf-8")
|
||||||
|
file_formatter = logging.Formatter(
|
||||||
|
"%(asctime)s - %(levelname)s - [pid=%(process)d] %(message)s"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(file_formatter)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
except Exception as exc: # pragma: no cover - best effort logging
|
||||||
|
logger.warning(f"Failed to attach file handler for log path {log_path}: {exc}")
|
||||||
|
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
|
_RAW_PROVIDER_OPTIONS = os.getenv("LEANN_EMBEDDING_OPTIONS")
|
||||||
|
try:
|
||||||
|
PROVIDER_OPTIONS: dict[str, Any] = (
|
||||||
|
json.loads(_RAW_PROVIDER_OPTIONS) if _RAW_PROVIDER_OPTIONS else {}
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Failed to parse LEANN_EMBEDDING_OPTIONS; ignoring provider options")
|
||||||
|
PROVIDER_OPTIONS = {}
|
||||||
|
|
||||||
|
|
||||||
def create_hnsw_embedding_server(
|
def create_hnsw_embedding_server(
|
||||||
passages_file: Optional[str] = None,
|
passages_file: Optional[str] = None,
|
||||||
@@ -90,9 +112,7 @@ def create_hnsw_embedding_server(
|
|||||||
embedding_dim: int = int(meta.get("dimensions", 0))
|
embedding_dim: int = int(meta.get("dimensions", 0))
|
||||||
except Exception:
|
except Exception:
|
||||||
embedding_dim = 0
|
embedding_dim = 0
|
||||||
logger.info(
|
logger.info(f"Loaded PassageManager with {len(passages)} passages from metadata")
|
||||||
f"Loaded PassageManager with {len(passages.global_offset_map)} passages from metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
# (legacy ZMQ thread removed; using shutdown-capable server only)
|
||||||
|
|
||||||
@@ -140,7 +160,12 @@ def create_hnsw_embedding_server(
|
|||||||
):
|
):
|
||||||
last_request_type = "text"
|
last_request_type = "text"
|
||||||
last_request_length = len(request)
|
last_request_length = len(request)
|
||||||
embeddings = compute_embeddings(request, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
request,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
rep_socket.send(msgpack.packb(embeddings.tolist()))
|
||||||
e2e_end = time.time()
|
e2e_end = time.time()
|
||||||
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
logger.info(f"⏱️ Text embedding E2E time: {e2e_end - e2e_start:.6f}s")
|
||||||
@@ -189,7 +214,10 @@ def create_hnsw_embedding_server(
|
|||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
embeddings = compute_embeddings(
|
embeddings = compute_embeddings(
|
||||||
texts, model_name, mode=embedding_mode
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
@@ -254,7 +282,12 @@ def create_hnsw_embedding_server(
|
|||||||
|
|
||||||
if texts:
|
if texts:
|
||||||
try:
|
try:
|
||||||
embeddings = compute_embeddings(texts, model_name, mode=embedding_mode)
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
mode=embedding_mode,
|
||||||
|
provider_options=PROVIDER_OPTIONS,
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
f"Computed embeddings for {len(texts)} texts, shape: {embeddings.shape}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,10 +6,10 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-backend-hnsw"
|
name = "leann-backend-hnsw"
|
||||||
version = "0.3.0"
|
version = "0.3.4"
|
||||||
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
description = "Custom-built HNSW (Faiss) backend for the Leann toolkit."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"leann-core==0.3.0",
|
"leann-core==0.3.4",
|
||||||
"numpy",
|
"numpy",
|
||||||
"pyzmq>=23.0.0",
|
"pyzmq>=23.0.0",
|
||||||
"msgpack>=1.0.0",
|
"msgpack>=1.0.0",
|
||||||
|
|||||||
Submodule packages/leann-backend-hnsw/third_party/faiss updated: 4a2c0d67d3...1d51f0c074
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann-core"
|
name = "leann-core"
|
||||||
version = "0.3.0"
|
version = "0.3.4"
|
||||||
description = "Core API and plugin system for LEANN"
|
description = "Core API and plugin system for LEANN"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -6,18 +6,22 @@ with the correct, original embedding logic from the user's reference code.
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
from leann.interface import LeannBackendSearcherInterface
|
from leann.interface import LeannBackendSearcherInterface
|
||||||
|
|
||||||
from .chat import get_llm
|
from .chat import get_llm
|
||||||
from .interface import LeannBackendFactoryInterface
|
from .interface import LeannBackendFactoryInterface
|
||||||
|
from .metadata_filter import MetadataFilterEngine
|
||||||
from .registry import BACKEND_REGISTRY
|
from .registry import BACKEND_REGISTRY
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -35,6 +39,7 @@ def compute_embeddings(
|
|||||||
use_server: bool = True,
|
use_server: bool = True,
|
||||||
port: Optional[int] = None,
|
port: Optional[int] = None,
|
||||||
is_build=False,
|
is_build=False,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Computes embeddings using different backends.
|
Computes embeddings using different backends.
|
||||||
@@ -68,6 +73,7 @@ def compute_embeddings(
|
|||||||
model_name,
|
model_name,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
|
provider_options=provider_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -119,9 +125,13 @@ class PassageManager:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
self, passage_sources: list[dict[str, Any]], metadata_file_path: Optional[str] = None
|
||||||
):
|
):
|
||||||
self.offset_maps = {}
|
self.offset_maps: dict[str, dict[str, int]] = {}
|
||||||
self.passage_files = {}
|
self.passage_files: dict[str, str] = {}
|
||||||
self.global_offset_map = {} # Combined map for fast lookup
|
# Avoid materializing a single gigantic global map to reduce memory
|
||||||
|
# footprint on very large corpora (e.g., 60M+ passages). Instead, keep
|
||||||
|
# per-shard maps and do a lightweight per-shard lookup on demand.
|
||||||
|
self._total_count: int = 0
|
||||||
|
self.filter_engine = MetadataFilterEngine() # Initialize filter engine
|
||||||
|
|
||||||
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
# Derive index base name for standard sibling fallbacks, e.g., <index_name>.passages.*
|
||||||
index_name_base = None
|
index_name_base = None
|
||||||
@@ -142,12 +152,25 @@ class PassageManager:
|
|||||||
default_name: Optional[str],
|
default_name: Optional[str],
|
||||||
source_dict: dict[str, Any],
|
source_dict: dict[str, Any],
|
||||||
) -> list[Path]:
|
) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Build an ordered list of candidate paths. For relative paths specified in
|
||||||
|
metadata, prefer resolution relative to the metadata file directory first,
|
||||||
|
then fall back to CWD-based resolution, and finally to conventional
|
||||||
|
sibling defaults (e.g., <index_base>.passages.idx / .jsonl).
|
||||||
|
"""
|
||||||
candidates: list[Path] = []
|
candidates: list[Path] = []
|
||||||
# 1) Primary as-is (absolute or relative)
|
# 1) Primary path
|
||||||
if primary:
|
if primary:
|
||||||
p = Path(primary)
|
p = Path(primary)
|
||||||
candidates.append(p if p.is_absolute() else (Path.cwd() / p))
|
if p.is_absolute():
|
||||||
# 2) metadata-relative explicit relative key
|
candidates.append(p)
|
||||||
|
else:
|
||||||
|
# Prefer metadata-relative resolution for relative paths
|
||||||
|
if metadata_file_path:
|
||||||
|
candidates.append(Path(metadata_file_path).parent / p)
|
||||||
|
# Also consider CWD-relative as a fallback for legacy layouts
|
||||||
|
candidates.append(Path.cwd() / p)
|
||||||
|
# 2) metadata-relative explicit relative key (if present)
|
||||||
if metadata_file_path and source_dict.get(relative_key):
|
if metadata_file_path and source_dict.get(relative_key):
|
||||||
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
candidates.append(Path(metadata_file_path).parent / source_dict[relative_key])
|
||||||
# 3) metadata-relative standard sibling filename
|
# 3) metadata-relative standard sibling filename
|
||||||
@@ -177,23 +200,78 @@ class PassageManager:
|
|||||||
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
raise FileNotFoundError(f"Passage index file not found: {index_file}")
|
||||||
|
|
||||||
with open(index_file, "rb") as f:
|
with open(index_file, "rb") as f:
|
||||||
offset_map = pickle.load(f)
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
self.offset_maps[passage_file] = offset_map
|
self.offset_maps[passage_file] = offset_map
|
||||||
self.passage_files[passage_file] = passage_file
|
self.passage_files[passage_file] = passage_file
|
||||||
|
self._total_count += len(offset_map)
|
||||||
# Build global map for O(1) lookup
|
|
||||||
for passage_id, offset in offset_map.items():
|
|
||||||
self.global_offset_map[passage_id] = (passage_file, offset)
|
|
||||||
|
|
||||||
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
def get_passage(self, passage_id: str) -> dict[str, Any]:
|
||||||
if passage_id in self.global_offset_map:
|
# Fast path: check each shard map (there are typically few shards).
|
||||||
passage_file, offset = self.global_offset_map[passage_id]
|
# This avoids building a massive combined dict while keeping lookups
|
||||||
# Lazy file opening - only open when needed
|
# bounded by the number of shards.
|
||||||
|
for passage_file, offset_map in self.offset_maps.items():
|
||||||
|
try:
|
||||||
|
offset = offset_map[passage_id]
|
||||||
with open(passage_file, encoding="utf-8") as f:
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
f.seek(offset)
|
f.seek(offset)
|
||||||
return json.loads(f.readline())
|
return json.loads(f.readline())
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
raise KeyError(f"Passage ID not found: {passage_id}")
|
raise KeyError(f"Passage ID not found: {passage_id}")
|
||||||
|
|
||||||
|
def filter_search_results(
|
||||||
|
self,
|
||||||
|
search_results: list[SearchResult],
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]],
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of SearchResult objects
|
||||||
|
metadata_filters: Filter specifications to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of SearchResult objects
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying metadata filters to {len(search_results)} results")
|
||||||
|
|
||||||
|
# Convert SearchResult objects to dictionaries for the filter engine
|
||||||
|
result_dicts = []
|
||||||
|
for result in search_results:
|
||||||
|
result_dicts.append(
|
||||||
|
{
|
||||||
|
"id": result.id,
|
||||||
|
"score": result.score,
|
||||||
|
"text": result.text,
|
||||||
|
"metadata": result.metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply filters using the filter engine
|
||||||
|
filtered_dicts = self.filter_engine.apply_filters(result_dicts, metadata_filters)
|
||||||
|
|
||||||
|
# Convert back to SearchResult objects
|
||||||
|
filtered_results = []
|
||||||
|
for result_dict in filtered_dicts:
|
||||||
|
filtered_results.append(
|
||||||
|
SearchResult(
|
||||||
|
id=result_dict["id"],
|
||||||
|
score=result_dict["score"],
|
||||||
|
text=result_dict["text"],
|
||||||
|
metadata=result_dict["metadata"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results: {len(filtered_results)} remaining")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return self._total_count
|
||||||
|
|
||||||
|
|
||||||
class LeannBuilder:
|
class LeannBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -202,6 +280,7 @@ class LeannBuilder:
|
|||||||
embedding_model: str = "facebook/contriever",
|
embedding_model: str = "facebook/contriever",
|
||||||
dimensions: Optional[int] = None,
|
dimensions: Optional[int] = None,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
embedding_options: Optional[dict[str, Any]] = None,
|
||||||
**backend_kwargs,
|
**backend_kwargs,
|
||||||
):
|
):
|
||||||
self.backend_name = backend_name
|
self.backend_name = backend_name
|
||||||
@@ -224,6 +303,7 @@ class LeannBuilder:
|
|||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
self.dimensions = dimensions
|
self.dimensions = dimensions
|
||||||
self.embedding_mode = embedding_mode
|
self.embedding_mode = embedding_mode
|
||||||
|
self.embedding_options = embedding_options or {}
|
||||||
|
|
||||||
# Check if we need to use cosine distance for normalized embeddings
|
# Check if we need to use cosine distance for normalized embeddings
|
||||||
normalized_embeddings_models = {
|
normalized_embeddings_models = {
|
||||||
@@ -331,6 +411,7 @@ class LeannBuilder:
|
|||||||
self.embedding_model,
|
self.embedding_model,
|
||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)[0]
|
)[0]
|
||||||
)
|
)
|
||||||
path = Path(index_path)
|
path = Path(index_path)
|
||||||
@@ -370,6 +451,7 @@ class LeannBuilder:
|
|||||||
self.embedding_mode,
|
self.embedding_mode,
|
||||||
use_server=False,
|
use_server=False,
|
||||||
is_build=True,
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
string_ids = [chunk["id"] for chunk in self.chunks]
|
string_ids = [chunk["id"] for chunk in self.chunks]
|
||||||
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
current_backend_kwargs = {**self.backend_kwargs, "dimensions": self.dimensions}
|
||||||
@@ -396,14 +478,15 @@ class LeannBuilder:
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
meta_data["is_compact"] = is_compact
|
meta_data["is_compact"] = is_compact
|
||||||
meta_data["is_pruned"] = (
|
meta_data["is_pruned"] = bool(is_recompute)
|
||||||
is_compact and is_recompute
|
|
||||||
) # Pruned only if compact and recompute
|
|
||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
@@ -518,18 +601,166 @@ class LeannBuilder:
|
|||||||
"embeddings_source": str(embeddings_file),
|
"embeddings_source": str(embeddings_file),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.embedding_options:
|
||||||
|
meta_data["embedding_options"] = self.embedding_options
|
||||||
|
|
||||||
# Add storage status flags for HNSW backend
|
# Add storage status flags for HNSW backend
|
||||||
if self.backend_name == "hnsw":
|
if self.backend_name == "hnsw":
|
||||||
is_compact = self.backend_kwargs.get("is_compact", True)
|
is_compact = self.backend_kwargs.get("is_compact", True)
|
||||||
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
is_recompute = self.backend_kwargs.get("is_recompute", True)
|
||||||
meta_data["is_compact"] = is_compact
|
meta_data["is_compact"] = is_compact
|
||||||
meta_data["is_pruned"] = is_compact and is_recompute
|
meta_data["is_pruned"] = bool(is_recompute)
|
||||||
|
|
||||||
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
with open(leann_meta_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(meta_data, f, indent=2)
|
json.dump(meta_data, f, indent=2)
|
||||||
|
|
||||||
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
logger.info(f"Index built successfully from precomputed embeddings: {index_path}")
|
||||||
|
|
||||||
|
def update_index(self, index_path: str):
|
||||||
|
"""Append new passages and vectors to an existing HNSW index."""
|
||||||
|
if not self.chunks:
|
||||||
|
raise ValueError("No new chunks provided for update.")
|
||||||
|
|
||||||
|
path = Path(index_path)
|
||||||
|
index_dir = path.parent
|
||||||
|
index_name = path.name
|
||||||
|
index_prefix = path.stem
|
||||||
|
|
||||||
|
meta_path = index_dir / f"{index_name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_name}.passages.jsonl"
|
||||||
|
offset_file = index_dir / f"{index_name}.passages.idx"
|
||||||
|
index_file = index_dir / f"{index_prefix}.index"
|
||||||
|
|
||||||
|
if not meta_path.exists() or not passages_file.exists() or not offset_file.exists():
|
||||||
|
raise FileNotFoundError("Index metadata or passage files are missing; cannot update.")
|
||||||
|
if not index_file.exists():
|
||||||
|
raise FileNotFoundError(f"HNSW index file not found: {index_file}")
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
backend_name = meta.get("backend_name")
|
||||||
|
if backend_name != self.backend_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Index was built with backend '{backend_name}', cannot update with '{self.backend_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_backend_kwargs = meta.get("backend_kwargs", {})
|
||||||
|
index_is_compact = meta.get("is_compact", meta_backend_kwargs.get("is_compact", True))
|
||||||
|
if index_is_compact:
|
||||||
|
raise ValueError(
|
||||||
|
"Compact HNSW indices do not support in-place updates. Rebuild required."
|
||||||
|
)
|
||||||
|
|
||||||
|
distance_metric = meta_backend_kwargs.get(
|
||||||
|
"distance_metric", self.backend_kwargs.get("distance_metric", "mips")
|
||||||
|
).lower()
|
||||||
|
needs_recompute = bool(
|
||||||
|
meta.get("is_pruned")
|
||||||
|
or meta_backend_kwargs.get("is_recompute")
|
||||||
|
or self.backend_kwargs.get("is_recompute")
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
for chunk in self.chunks:
|
||||||
|
text = chunk.get("text", "")
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
continue
|
||||||
|
metadata = chunk.setdefault("metadata", {})
|
||||||
|
passage_id = chunk.get("id") or metadata.get("id")
|
||||||
|
if passage_id and passage_id in existing_ids:
|
||||||
|
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||||
|
valid_chunks.append(chunk)
|
||||||
|
|
||||||
|
if not valid_chunks:
|
||||||
|
raise ValueError("No valid chunks to append.")
|
||||||
|
|
||||||
|
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts_to_embed,
|
||||||
|
self.embedding_model,
|
||||||
|
self.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
is_build=True,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_dim = embeddings.shape[1]
|
||||||
|
expected_dim = meta.get("dimensions")
|
||||||
|
if expected_dim is not None and expected_dim != embedding_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dimension mismatch during update: existing index uses {expected_dim}, got {embedding_dim}."
|
||||||
|
)
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
|
||||||
|
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||||
|
if distance_metric == "cosine":
|
||||||
|
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
embeddings = embeddings / norms
|
||||||
|
|
||||||
|
index = faiss.read_index(str(index_file))
|
||||||
|
if hasattr(index, "is_recompute"):
|
||||||
|
index.is_recompute = needs_recompute
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
if index.d != embedding_dim:
|
||||||
|
raise ValueError(
|
||||||
|
f"Existing index dimension ({index.d}) does not match new embeddings ({embedding_dim})."
|
||||||
|
)
|
||||||
|
|
||||||
|
base_id = index.ntotal
|
||||||
|
for offset, chunk in enumerate(valid_chunks):
|
||||||
|
new_id = str(base_id + offset)
|
||||||
|
chunk.setdefault("metadata", {})["id"] = new_id
|
||||||
|
chunk["id"] = new_id
|
||||||
|
|
||||||
|
index.add(embeddings.shape[0], faiss.swig_ptr(embeddings))
|
||||||
|
faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for chunk in valid_chunks:
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"id": chunk["id"],
|
||||||
|
"text": chunk["text"],
|
||||||
|
"metadata": chunk.get("metadata", {}),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[chunk["id"]] = offset
|
||||||
|
|
||||||
|
with open(offset_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Appended %d passages to index '%s'. New total: %d",
|
||||||
|
len(valid_chunks),
|
||||||
|
index_path,
|
||||||
|
len(offset_map),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.chunks.clear()
|
||||||
|
|
||||||
|
if needs_recompute:
|
||||||
|
prune_hnsw_embeddings_inplace(str(index_file))
|
||||||
|
|
||||||
|
|
||||||
class LeannSearcher:
|
class LeannSearcher:
|
||||||
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
def __init__(self, index_path: str, enable_warmup: bool = False, **backend_kwargs):
|
||||||
@@ -553,15 +784,20 @@ class LeannSearcher:
|
|||||||
self.embedding_model = self.meta_data["embedding_model"]
|
self.embedding_model = self.meta_data["embedding_model"]
|
||||||
# Support both old and new format
|
# Support both old and new format
|
||||||
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta_data.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.embedding_options = self.meta_data.get("embedding_options", {})
|
||||||
# Delegate portability handling to PassageManager
|
# Delegate portability handling to PassageManager
|
||||||
self.passage_manager = PassageManager(
|
self.passage_manager = PassageManager(
|
||||||
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
self.meta_data.get("passage_sources", []), metadata_file_path=self.meta_path_str
|
||||||
)
|
)
|
||||||
|
# Preserve backend name for conditional parameter forwarding
|
||||||
|
self.backend_name = backend_name
|
||||||
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
backend_factory = BACKEND_REGISTRY.get(backend_name)
|
||||||
if backend_factory is None:
|
if backend_factory is None:
|
||||||
raise ValueError(f"Backend '{backend_name}' not found.")
|
raise ValueError(f"Backend '{backend_name}' not found.")
|
||||||
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
final_kwargs = {**self.meta_data.get("backend_kwargs", {}), **backend_kwargs}
|
||||||
final_kwargs["enable_warmup"] = enable_warmup
|
final_kwargs["enable_warmup"] = enable_warmup
|
||||||
|
if self.embedding_options:
|
||||||
|
final_kwargs.setdefault("embedding_options", self.embedding_options)
|
||||||
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
self.backend_impl: LeannBackendSearcherInterface = backend_factory.searcher(
|
||||||
index_path, **final_kwargs
|
index_path, **final_kwargs
|
||||||
)
|
)
|
||||||
@@ -576,15 +812,49 @@ class LeannSearcher:
|
|||||||
recompute_embeddings: bool = True,
|
recompute_embeddings: bool = True,
|
||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SearchResult]:
|
) -> list[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search for nearest neighbors with optional metadata filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text query to search for
|
||||||
|
top_k: Number of nearest neighbors to return
|
||||||
|
complexity: Search complexity/candidate list size, higher = more accurate but slower
|
||||||
|
beam_width: Number of parallel search paths/IO requests per iteration
|
||||||
|
prune_ratio: Ratio of neighbors to prune via approximate distance (0.0-1.0)
|
||||||
|
recompute_embeddings: Whether to fetch fresh embeddings from server vs use stored codes
|
||||||
|
pruning_strategy: Candidate selection strategy - "global" (default), "local", or "proportional"
|
||||||
|
expected_zmq_port: ZMQ port for embedding server communication
|
||||||
|
metadata_filters: Optional filters to apply to search results based on metadata.
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
Supported operators:
|
||||||
|
- Comparison: "==", "!=", "<", "<=", ">", ">="
|
||||||
|
- Membership: "in", "not_in"
|
||||||
|
- String: "contains", "starts_with", "ends_with"
|
||||||
|
Example: {"chapter": {"<=": 5}, "tags": {"in": ["fiction", "drama"]}}
|
||||||
|
**kwargs: Backend-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects with text, metadata, and similarity scores
|
||||||
|
"""
|
||||||
|
# Handle grep search
|
||||||
|
if use_grep:
|
||||||
|
return self._grep_search(query, top_k)
|
||||||
|
|
||||||
logger.info("🔍 LeannSearcher.search() called:")
|
logger.info("🔍 LeannSearcher.search() called:")
|
||||||
logger.info(f" Query: '{query}'")
|
logger.info(f" Query: '{query}'")
|
||||||
logger.info(f" Top_k: {top_k}")
|
logger.info(f" Top_k: {top_k}")
|
||||||
|
logger.info(f" Metadata filters: {metadata_filters}")
|
||||||
logger.info(f" Additional kwargs: {kwargs}")
|
logger.info(f" Additional kwargs: {kwargs}")
|
||||||
|
|
||||||
# Smart top_k detection and adjustment
|
# Smart top_k detection and adjustment
|
||||||
total_docs = len(self.passage_manager.global_offset_map)
|
# Use PassageManager length (sum of shard sizes) to avoid
|
||||||
|
# depending on a massive combined map
|
||||||
|
total_docs = len(self.passage_manager)
|
||||||
original_top_k = top_k
|
original_top_k = top_k
|
||||||
if top_k > total_docs:
|
if top_k > total_docs:
|
||||||
top_k = total_docs
|
top_k = total_docs
|
||||||
@@ -613,23 +883,33 @@ class LeannSearcher:
|
|||||||
use_server_if_available=recompute_embeddings,
|
use_server_if_available=recompute_embeddings,
|
||||||
zmq_port=zmq_port,
|
zmq_port=zmq_port,
|
||||||
)
|
)
|
||||||
# logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
logger.info(f" Generated embedding shape: {query_embedding.shape}")
|
||||||
# time.time() - start_time
|
embedding_time = time.time() - start_time
|
||||||
# logger.info(f" Embedding time: {embedding_time} seconds")
|
logger.info(f" Embedding time: {embedding_time} seconds")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
backend_search_kwargs: dict[str, Any] = {
|
||||||
|
"complexity": complexity,
|
||||||
|
"beam_width": beam_width,
|
||||||
|
"prune_ratio": prune_ratio,
|
||||||
|
"recompute_embeddings": recompute_embeddings,
|
||||||
|
"pruning_strategy": pruning_strategy,
|
||||||
|
"zmq_port": zmq_port,
|
||||||
|
}
|
||||||
|
# Only HNSW supports batching; forward conditionally
|
||||||
|
if self.backend_name == "hnsw":
|
||||||
|
backend_search_kwargs["batch_size"] = batch_size
|
||||||
|
|
||||||
|
# Merge any extra kwargs last
|
||||||
|
backend_search_kwargs.update(kwargs)
|
||||||
|
|
||||||
results = self.backend_impl.search(
|
results = self.backend_impl.search(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
top_k,
|
top_k,
|
||||||
complexity=complexity,
|
**backend_search_kwargs,
|
||||||
beam_width=beam_width,
|
|
||||||
prune_ratio=prune_ratio,
|
|
||||||
recompute_embeddings=recompute_embeddings,
|
|
||||||
pruning_strategy=pruning_strategy,
|
|
||||||
zmq_port=zmq_port,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
search_time = time.time() - start_time
|
||||||
|
logger.info(f" Search time in search() LEANN searcher: {search_time} seconds")
|
||||||
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
logger.info(f" Backend returned: labels={len(results.get('labels', [[]])[0])} results")
|
||||||
|
|
||||||
enriched_results = []
|
enriched_results = []
|
||||||
@@ -668,15 +948,109 @@ class LeannSearcher:
|
|||||||
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
f" {RED}✗{RESET} [{i + 1:2d}] ID: '{string_id}' -> {RED}ERROR: Passage not found!{RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apply metadata filters if specified
|
||||||
|
if metadata_filters:
|
||||||
|
logger.info(f" 🔍 Applying metadata filters: {metadata_filters}")
|
||||||
|
enriched_results = self.passage_manager.filter_search_results(
|
||||||
|
enriched_results, metadata_filters
|
||||||
|
)
|
||||||
|
|
||||||
# Define color codes outside the loop for final message
|
# Define color codes outside the loop for final message
|
||||||
GREEN = "\033[92m"
|
GREEN = "\033[92m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
logger.info(f" {GREEN}✓ Final enriched results: {len(enriched_results)} passages{RESET}")
|
||||||
return enriched_results
|
return enriched_results
|
||||||
|
|
||||||
|
def _find_jsonl_file(self) -> Optional[str]:
|
||||||
|
"""Find the .jsonl file containing raw passages for grep search"""
|
||||||
|
index_path = Path(self.meta_path_str).parent
|
||||||
|
potential_files = [
|
||||||
|
index_path / "documents.leann.passages.jsonl",
|
||||||
|
index_path.parent / "documents.leann.passages.jsonl",
|
||||||
|
]
|
||||||
|
|
||||||
|
for file_path in potential_files:
|
||||||
|
if file_path.exists():
|
||||||
|
return str(file_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _grep_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Perform grep-based search on raw passages"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl passages file found for grep search")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = ["grep", "-i", "-n", query, jsonl_file]
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
|
||||||
|
|
||||||
|
if result.returncode == 1:
|
||||||
|
return []
|
||||||
|
elif result.returncode != 0:
|
||||||
|
raise RuntimeError(f"Grep failed: {result.stderr}")
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for line in result.stdout.strip().split("\n"):
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parts = line.split(":", 1)
|
||||||
|
if len(parts) != 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(parts[1])
|
||||||
|
text = data.get("text", "")
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", parts[0]),
|
||||||
|
text=text,
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(score),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"grep command not found. Please install grep or use semantic search."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _python_regex_search(self, query: str, top_k: int = 5) -> list[SearchResult]:
|
||||||
|
"""Fallback regex search"""
|
||||||
|
jsonl_file = self._find_jsonl_file()
|
||||||
|
if not jsonl_file:
|
||||||
|
raise FileNotFoundError("No .jsonl file found")
|
||||||
|
|
||||||
|
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
with open(jsonl_file, encoding="utf-8") as f:
|
||||||
|
for line_num, line in enumerate(f, 1):
|
||||||
|
if pattern.search(line):
|
||||||
|
try:
|
||||||
|
data = json.loads(line.strip())
|
||||||
|
matches.append(
|
||||||
|
SearchResult(
|
||||||
|
id=data.get("id", str(line_num)),
|
||||||
|
text=data.get("text", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
score=float(len(pattern.findall(data.get("text", "")))),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
matches.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return matches[:top_k]
|
||||||
|
|
||||||
def cleanup(self):
|
def cleanup(self):
|
||||||
"""Explicitly cleanup embedding server resources.
|
"""Explicitly cleanup embedding server resources.
|
||||||
|
|
||||||
This method should be called after you're done using the searcher,
|
This method should be called after you're done using the searcher,
|
||||||
especially in test environments or batch processing scenarios.
|
especially in test environments or batch processing scenarios.
|
||||||
"""
|
"""
|
||||||
@@ -708,9 +1082,15 @@ class LeannChat:
|
|||||||
index_path: str,
|
index_path: str,
|
||||||
llm_config: Optional[dict[str, Any]] = None,
|
llm_config: Optional[dict[str, Any]] = None,
|
||||||
enable_warmup: bool = False,
|
enable_warmup: bool = False,
|
||||||
|
searcher: Optional[LeannSearcher] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
if searcher is None:
|
||||||
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
self.searcher = LeannSearcher(index_path, enable_warmup=enable_warmup, **kwargs)
|
||||||
|
self._owns_searcher = True
|
||||||
|
else:
|
||||||
|
self.searcher = searcher
|
||||||
|
self._owns_searcher = False
|
||||||
self.llm = get_llm(llm_config)
|
self.llm = get_llm(llm_config)
|
||||||
|
|
||||||
def ask(
|
def ask(
|
||||||
@@ -724,6 +1104,9 @@ class LeannChat:
|
|||||||
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
pruning_strategy: Literal["global", "local", "proportional"] = "global",
|
||||||
llm_kwargs: Optional[dict[str, Any]] = None,
|
llm_kwargs: Optional[dict[str, Any]] = None,
|
||||||
expected_zmq_port: int = 5557,
|
expected_zmq_port: int = 5557,
|
||||||
|
metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None,
|
||||||
|
batch_size: int = 0,
|
||||||
|
use_grep: bool = False,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
):
|
):
|
||||||
if llm_kwargs is None:
|
if llm_kwargs is None:
|
||||||
@@ -738,10 +1121,12 @@ class LeannChat:
|
|||||||
recompute_embeddings=recompute_embeddings,
|
recompute_embeddings=recompute_embeddings,
|
||||||
pruning_strategy=pruning_strategy,
|
pruning_strategy=pruning_strategy,
|
||||||
expected_zmq_port=expected_zmq_port,
|
expected_zmq_port=expected_zmq_port,
|
||||||
|
metadata_filters=metadata_filters,
|
||||||
|
batch_size=batch_size,
|
||||||
**search_kwargs,
|
**search_kwargs,
|
||||||
)
|
)
|
||||||
search_time = time.time() - search_time
|
search_time = time.time() - search_time
|
||||||
# logger.info(f" Search time: {search_time} seconds")
|
logger.info(f" Search time: {search_time} seconds")
|
||||||
context = "\n\n".join([r.text for r in results])
|
context = "\n\n".join([r.text for r in results])
|
||||||
prompt = (
|
prompt = (
|
||||||
"Here is some retrieved context that might help answer your question:\n\n"
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
@@ -777,7 +1162,9 @@ class LeannChat:
|
|||||||
This method should be called after you're done using the chat interface,
|
This method should be called after you're done using the chat interface,
|
||||||
especially in test environments or batch processing scenarios.
|
especially in test environments or batch processing scenarios.
|
||||||
"""
|
"""
|
||||||
if hasattr(self.searcher, "cleanup"):
|
# Only stop the embedding server if this LeannChat instance created the searcher.
|
||||||
|
# When a shared searcher is passed in, avoid shutting down the server to enable reuse.
|
||||||
|
if getattr(self, "_owns_searcher", False) and hasattr(self.searcher, "cleanup"):
|
||||||
self.searcher.cleanup()
|
self.searcher.cleanup()
|
||||||
|
|
||||||
# Enable automatic cleanup patterns
|
# Enable automatic cleanup patterns
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -310,11 +312,12 @@ def search_hf_models(query: str, limit: int = 10) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def validate_model_and_suggest(
|
def validate_model_and_suggest(
|
||||||
model_name: str, llm_type: str, host: str = "http://localhost:11434"
|
model_name: str, llm_type: str, host: Optional[str] = None
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Validate model name and provide suggestions if invalid"""
|
"""Validate model name and provide suggestions if invalid"""
|
||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
available_models = check_ollama_models(host)
|
resolved_host = resolve_ollama_host(host)
|
||||||
|
available_models = check_ollama_models(resolved_host)
|
||||||
if available_models and model_name not in available_models:
|
if available_models and model_name not in available_models:
|
||||||
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
error_msg = f"Model '{model_name}' not found in your local Ollama installation."
|
||||||
|
|
||||||
@@ -457,19 +460,19 @@ class LLMInterface(ABC):
|
|||||||
class OllamaChat(LLMInterface):
|
class OllamaChat(LLMInterface):
|
||||||
"""LLM interface for Ollama models."""
|
"""LLM interface for Ollama models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "llama3:8b", host: str = "http://localhost:11434"):
|
def __init__(self, model: str = "llama3:8b", host: Optional[str] = None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.host = host
|
self.host = resolve_ollama_host(host)
|
||||||
logger.info(f"Initializing OllamaChat with model='{model}' and host='{host}'")
|
logger.info(f"Initializing OllamaChat with model='{model}' and host='{self.host}'")
|
||||||
try:
|
try:
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# Check if the Ollama server is responsive
|
# Check if the Ollama server is responsive
|
||||||
if host:
|
if self.host:
|
||||||
requests.get(host)
|
requests.get(self.host)
|
||||||
|
|
||||||
# Pre-check model availability with helpful suggestions
|
# Pre-check model availability with helpful suggestions
|
||||||
model_error = validate_model_and_suggest(model, "ollama", host)
|
model_error = validate_model_and_suggest(model, "ollama", self.host)
|
||||||
if model_error:
|
if model_error:
|
||||||
raise ValueError(model_error)
|
raise ValueError(model_error)
|
||||||
|
|
||||||
@@ -478,9 +481,11 @@ class OllamaChat(LLMInterface):
|
|||||||
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
"The 'requests' library is required for Ollama. Please install it with 'pip install requests'."
|
||||||
)
|
)
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
logger.error(f"Could not connect to Ollama at {host}. Please ensure Ollama is running.")
|
logger.error(
|
||||||
|
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||||
|
)
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
f"Could not connect to Ollama at {host}. Please ensure Ollama is running."
|
f"Could not connect to Ollama at {self.host}. Please ensure Ollama is running."
|
||||||
)
|
)
|
||||||
|
|
||||||
def ask(self, prompt: str, **kwargs) -> str:
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
@@ -737,21 +742,31 @@ class GeminiChat(LLMInterface):
|
|||||||
class OpenAIChat(LLMInterface):
|
class OpenAIChat(LLMInterface):
|
||||||
"""LLM interface for OpenAI models."""
|
"""LLM interface for OpenAI models."""
|
||||||
|
|
||||||
def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.base_url = resolve_openai_base_url(base_url)
|
||||||
|
self.api_key = resolve_openai_api_key(api_key)
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Initializing OpenAI Chat with model='{model}'")
|
logger.info(
|
||||||
|
"Initializing OpenAI Chat with model='%s' and base_url='%s'",
|
||||||
|
model,
|
||||||
|
self.base_url,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
self.client = openai.OpenAI(api_key=self.api_key)
|
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
"The 'openai' library is required for OpenAI models. Please install it with 'pip install openai'."
|
||||||
@@ -841,12 +856,16 @@ def get_llm(llm_config: Optional[dict[str, Any]] = None) -> LLMInterface:
|
|||||||
if llm_type == "ollama":
|
if llm_type == "ollama":
|
||||||
return OllamaChat(
|
return OllamaChat(
|
||||||
model=model or "llama3:8b",
|
model=model or "llama3:8b",
|
||||||
host=llm_config.get("host", "http://localhost:11434"),
|
host=llm_config.get("host"),
|
||||||
)
|
)
|
||||||
elif llm_type == "hf":
|
elif llm_type == "hf":
|
||||||
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
return HFChat(model_name=model or "deepseek-ai/deepseek-llm-7b-chat")
|
||||||
elif llm_type == "openai":
|
elif llm_type == "openai":
|
||||||
return OpenAIChat(model=model or "gpt-4o", api_key=llm_config.get("api_key"))
|
return OpenAIChat(
|
||||||
|
model=model or "gpt-4o",
|
||||||
|
api_key=llm_config.get("api_key"),
|
||||||
|
base_url=llm_config.get("base_url"),
|
||||||
|
)
|
||||||
elif llm_type == "gemini":
|
elif llm_type == "gemini":
|
||||||
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
return GeminiChat(model=model or "gemini-2.5-flash", api_key=llm_config.get("api_key"))
|
||||||
elif llm_type == "simulated":
|
elif llm_type == "simulated":
|
||||||
|
|||||||
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
220
packages/leann-core/src/leann/chunking_utils.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Enhanced chunking utilities with AST-aware code chunking support.
|
||||||
|
Packaged within leann-core so installed wheels can import it reliably.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Code file extensions supported by astchunk
|
||||||
|
CODE_EXTENSIONS = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "csharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
".js": "typescript",
|
||||||
|
".jsx": "typescript",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def detect_code_files(documents, code_extensions=None) -> tuple[list, list]:
|
||||||
|
"""Separate documents into code files and regular text files."""
|
||||||
|
if code_extensions is None:
|
||||||
|
code_extensions = CODE_EXTENSIONS
|
||||||
|
|
||||||
|
code_docs = []
|
||||||
|
text_docs = []
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "") or doc.metadata.get("file_name", "")
|
||||||
|
if file_path:
|
||||||
|
file_ext = Path(file_path).suffix.lower()
|
||||||
|
if file_ext in code_extensions:
|
||||||
|
doc.metadata["language"] = code_extensions[file_ext]
|
||||||
|
doc.metadata["is_code"] = True
|
||||||
|
code_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
else:
|
||||||
|
doc.metadata["is_code"] = False
|
||||||
|
text_docs.append(doc)
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(code_docs)} code files and {len(text_docs)} text files")
|
||||||
|
return code_docs, text_docs
|
||||||
|
|
||||||
|
|
||||||
|
def get_language_from_extension(file_path: str) -> Optional[str]:
|
||||||
|
"""Return language string from a filename/extension using CODE_EXTENSIONS."""
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
return CODE_EXTENSIONS.get(ext)
|
||||||
|
|
||||||
|
|
||||||
|
def create_ast_chunks(
|
||||||
|
documents,
|
||||||
|
max_chunk_size: int = 512,
|
||||||
|
chunk_overlap: int = 64,
|
||||||
|
metadata_template: str = "default",
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create AST-aware chunks from code documents using astchunk.
|
||||||
|
|
||||||
|
Falls back to traditional chunking if astchunk is unavailable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from astchunk import ASTChunkBuilder # optional dependency
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"astchunk not available: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking for code files")
|
||||||
|
return create_traditional_chunks(documents, max_chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
for doc in documents:
|
||||||
|
language = doc.metadata.get("language")
|
||||||
|
if not language:
|
||||||
|
logger.warning("No language detected; falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
configs = {
|
||||||
|
"max_chunk_size": max_chunk_size,
|
||||||
|
"language": language,
|
||||||
|
"metadata_template": metadata_template,
|
||||||
|
"chunk_overlap": chunk_overlap if chunk_overlap > 0 else 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
repo_metadata = {
|
||||||
|
"file_path": doc.metadata.get("file_path", ""),
|
||||||
|
"file_name": doc.metadata.get("file_name", ""),
|
||||||
|
"creation_date": doc.metadata.get("creation_date", ""),
|
||||||
|
"last_modified_date": doc.metadata.get("last_modified_date", ""),
|
||||||
|
}
|
||||||
|
configs["repo_level_metadata"] = repo_metadata
|
||||||
|
|
||||||
|
chunk_builder = ASTChunkBuilder(**configs)
|
||||||
|
code_content = doc.get_content()
|
||||||
|
if not code_content or not code_content.strip():
|
||||||
|
logger.warning("Empty code content, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunks = chunk_builder.chunkify(code_content)
|
||||||
|
for chunk in chunks:
|
||||||
|
if hasattr(chunk, "text"):
|
||||||
|
chunk_text = chunk.text
|
||||||
|
elif isinstance(chunk, dict) and "text" in chunk:
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
elif isinstance(chunk, str):
|
||||||
|
chunk_text = chunk
|
||||||
|
else:
|
||||||
|
chunk_text = str(chunk)
|
||||||
|
|
||||||
|
if chunk_text and chunk_text.strip():
|
||||||
|
all_chunks.append(chunk_text.strip())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created {len(chunks)} AST chunks from {language} file: {doc.metadata.get('file_name', 'unknown')}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"AST chunking failed for {language} file: {e}")
|
||||||
|
logger.info("Falling back to traditional chunking")
|
||||||
|
all_chunks.extend(create_traditional_chunks([doc], max_chunk_size, chunk_overlap))
|
||||||
|
|
||||||
|
return all_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def create_traditional_chunks(
|
||||||
|
documents, chunk_size: int = 256, chunk_overlap: int = 128
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create traditional text chunks using LlamaIndex SentenceSplitter."""
|
||||||
|
if chunk_size <= 0:
|
||||||
|
logger.warning(f"Invalid chunk_size={chunk_size}, using default value of 256")
|
||||||
|
chunk_size = 256
|
||||||
|
if chunk_overlap < 0:
|
||||||
|
chunk_overlap = 0
|
||||||
|
if chunk_overlap >= chunk_size:
|
||||||
|
chunk_overlap = chunk_size // 2
|
||||||
|
|
||||||
|
node_parser = SentenceSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap,
|
||||||
|
separator=" ",
|
||||||
|
paragraph_separator="\n\n",
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
for doc in documents:
|
||||||
|
try:
|
||||||
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
|
if nodes:
|
||||||
|
all_texts.extend(node.get_content() for node in nodes)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Traditional chunking failed for document: {e}")
|
||||||
|
content = doc.get_content()
|
||||||
|
if content and content.strip():
|
||||||
|
all_texts.append(content.strip())
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
def create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size: int = 256,
|
||||||
|
chunk_overlap: int = 128,
|
||||||
|
use_ast_chunking: bool = False,
|
||||||
|
ast_chunk_size: int = 512,
|
||||||
|
ast_chunk_overlap: int = 64,
|
||||||
|
code_file_extensions: Optional[list[str]] = None,
|
||||||
|
ast_fallback_traditional: bool = True,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Create text chunks from documents with optional AST support for code files."""
|
||||||
|
if not documents:
|
||||||
|
logger.warning("No documents provided for chunking")
|
||||||
|
return []
|
||||||
|
|
||||||
|
local_code_extensions = CODE_EXTENSIONS.copy()
|
||||||
|
if code_file_extensions:
|
||||||
|
ext_mapping = {
|
||||||
|
".py": "python",
|
||||||
|
".java": "java",
|
||||||
|
".cs": "c_sharp",
|
||||||
|
".ts": "typescript",
|
||||||
|
".tsx": "typescript",
|
||||||
|
}
|
||||||
|
for ext in code_file_extensions:
|
||||||
|
if ext.lower() not in local_code_extensions:
|
||||||
|
if ext.lower() in ext_mapping:
|
||||||
|
local_code_extensions[ext.lower()] = ext_mapping[ext.lower()]
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unsupported extension {ext}, will use traditional chunking")
|
||||||
|
|
||||||
|
all_chunks = []
|
||||||
|
if use_ast_chunking:
|
||||||
|
code_docs, text_docs = detect_code_files(documents, local_code_extensions)
|
||||||
|
if code_docs:
|
||||||
|
try:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_ast_chunks(
|
||||||
|
code_docs, max_chunk_size=ast_chunk_size, chunk_overlap=ast_chunk_overlap
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AST chunking failed: {e}")
|
||||||
|
if ast_fallback_traditional:
|
||||||
|
all_chunks.extend(
|
||||||
|
create_traditional_chunks(code_docs, chunk_size, chunk_overlap)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
if text_docs:
|
||||||
|
all_chunks.extend(create_traditional_chunks(text_docs, chunk_size, chunk_overlap))
|
||||||
|
else:
|
||||||
|
all_chunks = create_traditional_chunks(documents, chunk_size, chunk_overlap)
|
||||||
|
|
||||||
|
logger.info(f"Total chunks created: {len(all_chunks)}")
|
||||||
|
return all_chunks
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
@@ -9,6 +9,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from .api import LeannBuilder, LeannChat, LeannSearcher
|
from .api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
from .registry import register_project_directory
|
from .registry import register_project_directory
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
|
||||||
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
def extract_pdf_text_with_pymupdf(file_path: str) -> str:
|
||||||
@@ -123,6 +124,24 @@ Examples:
|
|||||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
help="Embedding backend mode (default: sentence-transformers)",
|
help="Embedding backend mode (default: sentence-transformers)",
|
||||||
)
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
build_parser.add_argument(
|
build_parser.add_argument(
|
||||||
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
"--force", "-f", action="store_true", help="Force rebuild existing index"
|
||||||
)
|
)
|
||||||
@@ -180,6 +199,29 @@ Examples:
|
|||||||
default=50,
|
default=50,
|
||||||
help="Code chunk overlap (default: 50)",
|
help="Code chunk overlap (default: 50)",
|
||||||
)
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=768,
|
||||||
|
help="AST chunk size in characters (default: 768)",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=96,
|
||||||
|
help="AST chunk overlap in characters (default: 96)",
|
||||||
|
)
|
||||||
|
build_parser.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
# Search command
|
# Search command
|
||||||
search_parser = subparsers.add_parser("search", help="Search documents")
|
search_parser = subparsers.add_parser("search", help="Search documents")
|
||||||
@@ -206,10 +248,20 @@ Examples:
|
|||||||
default="global",
|
default="global",
|
||||||
help="Pruning strategy (default: global)",
|
help="Pruning strategy (default: global)",
|
||||||
)
|
)
|
||||||
|
search_parser.add_argument(
|
||||||
|
"--non-interactive",
|
||||||
|
action="store_true",
|
||||||
|
help="Non-interactive mode: automatically select index without prompting",
|
||||||
|
)
|
||||||
|
|
||||||
# Ask command
|
# Ask command
|
||||||
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
ask_parser = subparsers.add_parser("ask", help="Ask questions")
|
||||||
ask_parser.add_argument("index_name", help="Index name")
|
ask_parser.add_argument("index_name", help="Index name")
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"query",
|
||||||
|
nargs="?",
|
||||||
|
help="Question to ask (omit for prompt or when using --interactive)",
|
||||||
|
)
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--llm",
|
"--llm",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -220,7 +272,12 @@ Examples:
|
|||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
"--model", type=str, default="qwen3:8b", help="Model name (default: qwen3:8b)"
|
||||||
)
|
)
|
||||||
ask_parser.add_argument("--host", type=str, default="http://localhost:11434")
|
ask_parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible host (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
|
)
|
||||||
ask_parser.add_argument(
|
ask_parser.add_argument(
|
||||||
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
"--interactive", "-i", action="store_true", help="Interactive chat mode"
|
||||||
)
|
)
|
||||||
@@ -249,6 +306,18 @@ Examples:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
)
|
)
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs (e.g., http://localhost:10000/v1)",
|
||||||
|
)
|
||||||
|
ask_parser.add_argument(
|
||||||
|
"--api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
# List command
|
# List command
|
||||||
subparsers.add_parser("list", help="List all indexes")
|
subparsers.add_parser("list", help="List all indexes")
|
||||||
@@ -293,9 +362,17 @@ Examples:
|
|||||||
|
|
||||||
return basic_matches
|
return basic_matches
|
||||||
|
|
||||||
def _should_exclude_file(self, relative_path: Path, gitignore_matches) -> bool:
|
def _should_exclude_file(self, file_path: Path, gitignore_matches) -> bool:
|
||||||
"""Check if a file should be excluded using gitignore parser."""
|
"""Check if a file should be excluded using gitignore parser.
|
||||||
return gitignore_matches(str(relative_path))
|
|
||||||
|
Always match against absolute, posix-style paths for consistency with
|
||||||
|
gitignore_parser expectations.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
absolute_path = file_path.resolve()
|
||||||
|
except Exception:
|
||||||
|
absolute_path = Path(str(file_path))
|
||||||
|
return gitignore_matches(absolute_path.as_posix())
|
||||||
|
|
||||||
def _is_git_submodule(self, path: Path) -> bool:
|
def _is_git_submodule(self, path: Path) -> bool:
|
||||||
"""Check if a path is a git submodule."""
|
"""Check if a path is a git submodule."""
|
||||||
@@ -367,7 +444,9 @@ Examples:
|
|||||||
print(f" {current_path}")
|
print(f" {current_path}")
|
||||||
print(" " + "─" * 45)
|
print(" " + "─" * 45)
|
||||||
|
|
||||||
current_indexes = self._discover_indexes_in_project(current_path)
|
current_indexes = self._discover_indexes_in_project(
|
||||||
|
current_path, exclude_dirs=other_projects
|
||||||
|
)
|
||||||
if current_indexes:
|
if current_indexes:
|
||||||
for idx in current_indexes:
|
for idx in current_indexes:
|
||||||
total_indexes += 1
|
total_indexes += 1
|
||||||
@@ -405,14 +484,15 @@ Examples:
|
|||||||
print("💡 Get started:")
|
print("💡 Get started:")
|
||||||
print(" leann build my-docs --docs ./documents")
|
print(" leann build my-docs --docs ./documents")
|
||||||
else:
|
else:
|
||||||
projects_count = len(
|
# Count only projects that have at least one discoverable index
|
||||||
[
|
projects_count = 0
|
||||||
p
|
for p in valid_projects:
|
||||||
for p in valid_projects
|
if p == current_path:
|
||||||
if (p / ".leann" / "indexes").exists()
|
discovered = self._discover_indexes_in_project(p, exclude_dirs=other_projects)
|
||||||
and list((p / ".leann" / "indexes").iterdir())
|
else:
|
||||||
]
|
discovered = self._discover_indexes_in_project(p)
|
||||||
)
|
if len(discovered) > 0:
|
||||||
|
projects_count += 1
|
||||||
print(f"📊 Total: {total_indexes} indexes across {projects_count} projects")
|
print(f"📊 Total: {total_indexes} indexes across {projects_count} projects")
|
||||||
|
|
||||||
if current_indexes_count > 0:
|
if current_indexes_count > 0:
|
||||||
@@ -429,9 +509,22 @@ Examples:
|
|||||||
print("\n💡 Create your first index:")
|
print("\n💡 Create your first index:")
|
||||||
print(" leann build my-docs --docs ./documents")
|
print(" leann build my-docs --docs ./documents")
|
||||||
|
|
||||||
def _discover_indexes_in_project(self, project_path: Path):
|
def _discover_indexes_in_project(
|
||||||
"""Discover all indexes in a project directory (both CLI and apps formats)"""
|
self, project_path: Path, exclude_dirs: Optional[list[Path]] = None
|
||||||
|
):
|
||||||
|
"""Discover all indexes in a project directory (both CLI and apps formats)
|
||||||
|
|
||||||
|
exclude_dirs: when provided, skip any APP-format index files that are
|
||||||
|
located under these directories. This prevents duplicates when the
|
||||||
|
current project is a parent directory of other registered projects.
|
||||||
|
"""
|
||||||
indexes = []
|
indexes = []
|
||||||
|
exclude_dirs = exclude_dirs or []
|
||||||
|
# normalize to resolved paths once for comparison
|
||||||
|
try:
|
||||||
|
exclude_dirs_resolved = [p.resolve() for p in exclude_dirs]
|
||||||
|
except Exception:
|
||||||
|
exclude_dirs_resolved = exclude_dirs
|
||||||
|
|
||||||
# 1. CLI format: .leann/indexes/index_name/
|
# 1. CLI format: .leann/indexes/index_name/
|
||||||
cli_indexes_dir = project_path / ".leann" / "indexes"
|
cli_indexes_dir = project_path / ".leann" / "indexes"
|
||||||
@@ -461,26 +554,46 @@ Examples:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 2. Apps format: *.leann.meta.json files anywhere in the project
|
# 2. Apps format: *.leann.meta.json files anywhere in the project
|
||||||
|
cli_indexes_dir = project_path / ".leann" / "indexes"
|
||||||
for meta_file in project_path.rglob("*.leann.meta.json"):
|
for meta_file in project_path.rglob("*.leann.meta.json"):
|
||||||
if meta_file.is_file():
|
if meta_file.is_file():
|
||||||
# Extract index name from filename (remove .leann.meta.json extension)
|
# Skip CLI-built indexes (which store meta under .leann/indexes/<name>/)
|
||||||
index_name = meta_file.name.replace(".leann.meta.json", "")
|
try:
|
||||||
|
if cli_indexes_dir.exists() and cli_indexes_dir in meta_file.parents:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Skip meta files that live under excluded directories
|
||||||
|
try:
|
||||||
|
meta_parent_resolved = meta_file.parent.resolve()
|
||||||
|
if any(
|
||||||
|
meta_parent_resolved.is_relative_to(ex_dir)
|
||||||
|
for ex_dir in exclude_dirs_resolved
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
# best effort; if resolve or comparison fails, do not exclude
|
||||||
|
pass
|
||||||
|
# Use the parent directory name as the app index display name
|
||||||
|
display_name = meta_file.parent.name
|
||||||
|
# Extract file base used to store files
|
||||||
|
file_base = meta_file.name.replace(".leann.meta.json", "")
|
||||||
|
|
||||||
# Apps indexes are considered complete if the .leann.meta.json file exists
|
# Apps indexes are considered complete if the .leann.meta.json file exists
|
||||||
status = "✅"
|
status = "✅"
|
||||||
|
|
||||||
# Calculate total size of all related files
|
# Calculate total size of all related files (use file base)
|
||||||
size_mb = 0
|
size_mb = 0
|
||||||
try:
|
try:
|
||||||
index_dir = meta_file.parent
|
index_dir = meta_file.parent
|
||||||
for related_file in index_dir.glob(f"{index_name}.leann*"):
|
for related_file in index_dir.glob(f"{file_base}.leann*"):
|
||||||
size_mb += related_file.stat().st_size / (1024 * 1024)
|
size_mb += related_file.stat().st_size / (1024 * 1024)
|
||||||
except (OSError, PermissionError):
|
except (OSError, PermissionError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
indexes.append(
|
indexes.append(
|
||||||
{
|
{
|
||||||
"name": index_name,
|
"name": display_name,
|
||||||
"type": "app",
|
"type": "app",
|
||||||
"status": status,
|
"status": status,
|
||||||
"size_mb": size_mb,
|
"size_mb": size_mb,
|
||||||
@@ -534,11 +647,77 @@ Examples:
|
|||||||
if not project_path.exists():
|
if not project_path.exists():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 1) CLI-format index under .leann/indexes/<name>
|
||||||
index_dir = project_path / ".leann" / "indexes" / index_name
|
index_dir = project_path / ".leann" / "indexes" / index_name
|
||||||
if index_dir.exists():
|
if index_dir.exists():
|
||||||
is_current = project_path == current_path
|
is_current = project_path == current_path
|
||||||
matches.append(
|
matches.append(
|
||||||
{"project_path": project_path, "index_dir": index_dir, "is_current": is_current}
|
{
|
||||||
|
"project_path": project_path,
|
||||||
|
"index_dir": index_dir,
|
||||||
|
"is_current": is_current,
|
||||||
|
"kind": "cli",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) App-format indexes
|
||||||
|
# We support two ways of addressing apps:
|
||||||
|
# a) by the file base (e.g., `pdf_documents`)
|
||||||
|
# b) by the parent directory name (e.g., `new_txt`)
|
||||||
|
seen_app_meta = set()
|
||||||
|
|
||||||
|
# 2a) by file base
|
||||||
|
for meta_file in project_path.rglob(f"{index_name}.leann.meta.json"):
|
||||||
|
if meta_file.is_file():
|
||||||
|
# Skip CLI-built indexes' meta under .leann/indexes
|
||||||
|
try:
|
||||||
|
cli_indexes_dir = project_path / ".leann" / "indexes"
|
||||||
|
if cli_indexes_dir.exists() and cli_indexes_dir in meta_file.parents:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
is_current = project_path == current_path
|
||||||
|
key = (str(project_path), str(meta_file))
|
||||||
|
if key in seen_app_meta:
|
||||||
|
continue
|
||||||
|
seen_app_meta.add(key)
|
||||||
|
matches.append(
|
||||||
|
{
|
||||||
|
"project_path": project_path,
|
||||||
|
"files_dir": meta_file.parent,
|
||||||
|
"meta_file": meta_file,
|
||||||
|
"is_current": is_current,
|
||||||
|
"kind": "app",
|
||||||
|
"display_name": meta_file.parent.name,
|
||||||
|
"file_base": meta_file.name.replace(".leann.meta.json", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2b) by parent directory name
|
||||||
|
for meta_file in project_path.rglob("*.leann.meta.json"):
|
||||||
|
if meta_file.is_file() and meta_file.parent.name == index_name:
|
||||||
|
# Skip CLI-built indexes' meta under .leann/indexes
|
||||||
|
try:
|
||||||
|
cli_indexes_dir = project_path / ".leann" / "indexes"
|
||||||
|
if cli_indexes_dir.exists() and cli_indexes_dir in meta_file.parents:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
is_current = project_path == current_path
|
||||||
|
key = (str(project_path), str(meta_file))
|
||||||
|
if key in seen_app_meta:
|
||||||
|
continue
|
||||||
|
seen_app_meta.add(key)
|
||||||
|
matches.append(
|
||||||
|
{
|
||||||
|
"project_path": project_path,
|
||||||
|
"files_dir": meta_file.parent,
|
||||||
|
"meta_file": meta_file,
|
||||||
|
"is_current": is_current,
|
||||||
|
"kind": "app",
|
||||||
|
"display_name": meta_file.parent.name,
|
||||||
|
"file_base": meta_file.name.replace(".leann.meta.json", ""),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sort: current project first, then by project name
|
# Sort: current project first, then by project name
|
||||||
@@ -548,8 +727,8 @@ Examples:
|
|||||||
def _remove_single_match(self, match, index_name: str, force: bool):
|
def _remove_single_match(self, match, index_name: str, force: bool):
|
||||||
"""Handle removal when only one match is found"""
|
"""Handle removal when only one match is found"""
|
||||||
project_path = match["project_path"]
|
project_path = match["project_path"]
|
||||||
index_dir = match["index_dir"]
|
|
||||||
is_current = match["is_current"]
|
is_current = match["is_current"]
|
||||||
|
kind = match.get("kind", "cli")
|
||||||
|
|
||||||
if is_current:
|
if is_current:
|
||||||
location_info = "current project"
|
location_info = "current project"
|
||||||
@@ -560,7 +739,10 @@ Examples:
|
|||||||
|
|
||||||
print(f"✅ Found 1 index named '{index_name}':")
|
print(f"✅ Found 1 index named '{index_name}':")
|
||||||
print(f" {emoji} Location: {location_info}")
|
print(f" {emoji} Location: {location_info}")
|
||||||
print(f" 📍 Path: {project_path}")
|
if kind == "cli":
|
||||||
|
print(f" 📍 Path: {project_path / '.leann' / 'indexes' / index_name}")
|
||||||
|
else:
|
||||||
|
print(f" 📍 Meta: {match['meta_file']}")
|
||||||
|
|
||||||
if not force:
|
if not force:
|
||||||
if not is_current:
|
if not is_current:
|
||||||
@@ -572,8 +754,21 @@ Examples:
|
|||||||
print(" ❌ Removal cancelled.")
|
print(" ❌ Removal cancelled.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if kind == "cli":
|
||||||
return self._delete_index_directory(
|
return self._delete_index_directory(
|
||||||
index_dir, index_name, project_path if not is_current else None
|
match["index_dir"],
|
||||||
|
index_name,
|
||||||
|
project_path if not is_current else None,
|
||||||
|
is_app=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._delete_index_directory(
|
||||||
|
match["files_dir"],
|
||||||
|
match.get("display_name", index_name),
|
||||||
|
project_path if not is_current else None,
|
||||||
|
is_app=True,
|
||||||
|
meta_file=match.get("meta_file"),
|
||||||
|
app_file_base=match.get("file_base"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _remove_from_multiple_matches(self, matches, index_name: str, force: bool):
|
def _remove_from_multiple_matches(self, matches, index_name: str, force: bool):
|
||||||
@@ -585,19 +780,34 @@ Examples:
|
|||||||
for i, match in enumerate(matches, 1):
|
for i, match in enumerate(matches, 1):
|
||||||
project_path = match["project_path"]
|
project_path = match["project_path"]
|
||||||
is_current = match["is_current"]
|
is_current = match["is_current"]
|
||||||
|
kind = match.get("kind", "cli")
|
||||||
|
|
||||||
if is_current:
|
if is_current:
|
||||||
print(f" {i}. 🏠 Current project")
|
print(f" {i}. 🏠 Current project ({'CLI' if kind == 'cli' else 'APP'})")
|
||||||
print(f" 📍 {project_path}")
|
|
||||||
else:
|
else:
|
||||||
print(f" {i}. 📂 {project_path.name}")
|
print(f" {i}. 📂 {project_path.name} ({'CLI' if kind == 'cli' else 'APP'})")
|
||||||
print(f" 📍 {project_path}")
|
|
||||||
|
# Show path details
|
||||||
|
if kind == "cli":
|
||||||
|
print(f" 📍 {project_path / '.leann' / 'indexes' / index_name}")
|
||||||
|
else:
|
||||||
|
print(f" 📍 {match['meta_file']}")
|
||||||
|
|
||||||
# Show size info
|
# Show size info
|
||||||
try:
|
try:
|
||||||
|
if kind == "cli":
|
||||||
size_mb = sum(
|
size_mb = sum(
|
||||||
f.stat().st_size for f in match["index_dir"].iterdir() if f.is_file()
|
f.stat().st_size for f in match["index_dir"].iterdir() if f.is_file()
|
||||||
) / (1024 * 1024)
|
) / (1024 * 1024)
|
||||||
|
else:
|
||||||
|
file_base = match.get("file_base")
|
||||||
|
size_mb = 0.0
|
||||||
|
if file_base:
|
||||||
|
size_mb = sum(
|
||||||
|
f.stat().st_size
|
||||||
|
for f in match["files_dir"].glob(f"{file_base}.leann*")
|
||||||
|
if f.is_file()
|
||||||
|
) / (1024 * 1024)
|
||||||
print(f" 📦 Size: {size_mb:.1f} MB")
|
print(f" 📦 Size: {size_mb:.1f} MB")
|
||||||
except (OSError, PermissionError):
|
except (OSError, PermissionError):
|
||||||
pass
|
pass
|
||||||
@@ -621,8 +831,8 @@ Examples:
|
|||||||
if 0 <= choice_idx < len(matches):
|
if 0 <= choice_idx < len(matches):
|
||||||
selected_match = matches[choice_idx]
|
selected_match = matches[choice_idx]
|
||||||
project_path = selected_match["project_path"]
|
project_path = selected_match["project_path"]
|
||||||
index_dir = selected_match["index_dir"]
|
|
||||||
is_current = selected_match["is_current"]
|
is_current = selected_match["is_current"]
|
||||||
|
kind = selected_match.get("kind", "cli")
|
||||||
|
|
||||||
location = "current project" if is_current else f"'{project_path.name}' project"
|
location = "current project" if is_current else f"'{project_path.name}' project"
|
||||||
print(f" 🎯 Selected: Remove from {location}")
|
print(f" 🎯 Selected: Remove from {location}")
|
||||||
@@ -635,8 +845,21 @@ Examples:
|
|||||||
print(" ❌ Confirmation failed. Removal cancelled.")
|
print(" ❌ Confirmation failed. Removal cancelled.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if kind == "cli":
|
||||||
return self._delete_index_directory(
|
return self._delete_index_directory(
|
||||||
index_dir, index_name, project_path if not is_current else None
|
selected_match["index_dir"],
|
||||||
|
index_name,
|
||||||
|
project_path if not is_current else None,
|
||||||
|
is_app=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self._delete_index_directory(
|
||||||
|
selected_match["files_dir"],
|
||||||
|
selected_match.get("display_name", index_name),
|
||||||
|
project_path if not is_current else None,
|
||||||
|
is_app=True,
|
||||||
|
meta_file=selected_match.get("meta_file"),
|
||||||
|
app_file_base=selected_match.get("file_base"),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(" ❌ Invalid choice. Removal cancelled.")
|
print(" ❌ Invalid choice. Removal cancelled.")
|
||||||
@@ -647,21 +870,65 @@ Examples:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def _delete_index_directory(
|
def _delete_index_directory(
|
||||||
self, index_dir: Path, index_name: str, project_path: Optional[Path] = None
|
self,
|
||||||
|
index_dir: Path,
|
||||||
|
index_display_name: str,
|
||||||
|
project_path: Optional[Path] = None,
|
||||||
|
is_app: bool = False,
|
||||||
|
meta_file: Optional[Path] = None,
|
||||||
|
app_file_base: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Actually delete the index directory"""
|
"""Delete a CLI index directory or APP index files safely."""
|
||||||
try:
|
try:
|
||||||
|
if is_app:
|
||||||
|
removed = 0
|
||||||
|
errors = 0
|
||||||
|
# Delete only files that belong to this app index (based on file base)
|
||||||
|
pattern_base = app_file_base or ""
|
||||||
|
for f in index_dir.glob(f"{pattern_base}.leann*"):
|
||||||
|
try:
|
||||||
|
f.unlink()
|
||||||
|
removed += 1
|
||||||
|
except Exception:
|
||||||
|
errors += 1
|
||||||
|
# Best-effort: also remove the meta file if specified and still exists
|
||||||
|
if meta_file and meta_file.exists():
|
||||||
|
try:
|
||||||
|
meta_file.unlink()
|
||||||
|
removed += 1
|
||||||
|
except Exception:
|
||||||
|
errors += 1
|
||||||
|
|
||||||
|
if removed > 0 and errors == 0:
|
||||||
|
if project_path:
|
||||||
|
print(
|
||||||
|
f"✅ App index '{index_display_name}' removed from {project_path.name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"✅ App index '{index_display_name}' removed successfully")
|
||||||
|
return True
|
||||||
|
elif removed > 0 and errors > 0:
|
||||||
|
print(
|
||||||
|
f"⚠️ App index '{index_display_name}' partially removed (some files couldn't be deleted)"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"❌ No files found to remove for app index '{index_display_name}' in {index_dir}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(index_dir)
|
shutil.rmtree(index_dir)
|
||||||
|
|
||||||
if project_path:
|
if project_path:
|
||||||
print(f"✅ Index '{index_name}' removed from {project_path.name}")
|
print(f"✅ Index '{index_display_name}' removed from {project_path.name}")
|
||||||
else:
|
else:
|
||||||
print(f"✅ Index '{index_name}' removed successfully")
|
print(f"✅ Index '{index_display_name}' removed successfully")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ Error removing index '{index_name}': {e}")
|
print(f"❌ Error removing index '{index_display_name}': {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_documents(
|
def load_documents(
|
||||||
@@ -669,6 +936,7 @@ Examples:
|
|||||||
docs_paths: Union[str, list],
|
docs_paths: Union[str, list],
|
||||||
custom_file_types: Union[str, None] = None,
|
custom_file_types: Union[str, None] = None,
|
||||||
include_hidden: bool = False,
|
include_hidden: bool = False,
|
||||||
|
args: Optional[dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
# Handle both single path (string) and multiple paths (list) for backward compatibility
|
# Handle both single path (string) and multiple paths (list) for backward compatibility
|
||||||
if isinstance(docs_paths, str):
|
if isinstance(docs_paths, str):
|
||||||
@@ -833,7 +1101,8 @@ Examples:
|
|||||||
|
|
||||||
# Try to use better PDF parsers first, but only if PDFs are requested
|
# Try to use better PDF parsers first, but only if PDFs are requested
|
||||||
documents = []
|
documents = []
|
||||||
docs_path = Path(docs_dir)
|
# Use resolved absolute paths to avoid mismatches (symlinks, relative vs absolute)
|
||||||
|
docs_path = Path(docs_dir).resolve()
|
||||||
|
|
||||||
# Check if we should process PDFs
|
# Check if we should process PDFs
|
||||||
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
should_process_pdfs = custom_file_types is None or ".pdf" in custom_file_types
|
||||||
@@ -842,10 +1111,15 @@ Examples:
|
|||||||
for file_path in docs_path.rglob("*.pdf"):
|
for file_path in docs_path.rglob("*.pdf"):
|
||||||
# Check if file matches any exclude pattern
|
# Check if file matches any exclude pattern
|
||||||
try:
|
try:
|
||||||
|
# Ensure both paths are resolved before computing relativity
|
||||||
|
file_path_resolved = file_path.resolve()
|
||||||
|
# Determine directory scope using the non-resolved path to avoid
|
||||||
|
# misclassifying symlinked entries as outside the docs directory
|
||||||
relative_path = file_path.relative_to(docs_path)
|
relative_path = file_path.relative_to(docs_path)
|
||||||
if not include_hidden and _path_has_hidden_segment(relative_path):
|
if not include_hidden and _path_has_hidden_segment(relative_path):
|
||||||
continue
|
continue
|
||||||
if self._should_exclude_file(relative_path, gitignore_matches):
|
# Use absolute path for gitignore matching
|
||||||
|
if self._should_exclude_file(file_path_resolved, gitignore_matches):
|
||||||
continue
|
continue
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# Skip files that can't be made relative to docs_path
|
# Skip files that can't be made relative to docs_path
|
||||||
@@ -888,10 +1162,11 @@ Examples:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Return True if file should be included (not excluded)"""
|
"""Return True if file should be included (not excluded)"""
|
||||||
try:
|
try:
|
||||||
docs_path_obj = Path(docs_dir)
|
docs_path_obj = Path(docs_dir).resolve()
|
||||||
file_path_obj = Path(file_path)
|
file_path_obj = Path(file_path).resolve()
|
||||||
relative_path = file_path_obj.relative_to(docs_path_obj)
|
# Use absolute path for gitignore matching
|
||||||
return not self._should_exclude_file(relative_path, gitignore_matches)
|
_ = file_path_obj.relative_to(docs_path_obj) # validate scope
|
||||||
|
return not self._should_exclude_file(file_path_obj, gitignore_matches)
|
||||||
except (ValueError, OSError):
|
except (ValueError, OSError):
|
||||||
return True # Include files that can't be processed
|
return True # Include files that can't be processed
|
||||||
|
|
||||||
@@ -974,7 +1249,36 @@ Examples:
|
|||||||
}
|
}
|
||||||
|
|
||||||
print("start chunking documents")
|
print("start chunking documents")
|
||||||
# Add progress bar for document chunking
|
|
||||||
|
# Check if AST chunking is requested
|
||||||
|
use_ast = getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("🧠 Using AST-aware chunking for code files")
|
||||||
|
try:
|
||||||
|
# Import enhanced chunking utilities from packaged module
|
||||||
|
from .chunking_utils import create_text_chunks
|
||||||
|
|
||||||
|
# Use enhanced chunking with AST support
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=self.node_parser.chunk_size,
|
||||||
|
chunk_overlap=self.node_parser.chunk_overlap,
|
||||||
|
use_ast_chunking=True,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 768),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 96),
|
||||||
|
code_file_extensions=None, # Use defaults
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(
|
||||||
|
f"⚠️ AST chunking utilities not available in package ({e}), falling back to traditional chunking"
|
||||||
|
)
|
||||||
|
use_ast = False
|
||||||
|
|
||||||
|
if not use_ast:
|
||||||
|
# Use traditional chunking logic
|
||||||
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
for doc in tqdm(documents, desc="Chunking documents", unit="doc"):
|
||||||
# Check if this is a code file based on source path
|
# Check if this is a code file based on source path
|
||||||
source_path = doc.metadata.get("source", "")
|
source_path = doc.metadata.get("source", "")
|
||||||
@@ -1052,7 +1356,7 @@ Examples:
|
|||||||
)
|
)
|
||||||
|
|
||||||
all_texts = self.load_documents(
|
all_texts = self.load_documents(
|
||||||
docs_paths, args.file_types, include_hidden=args.include_hidden
|
docs_paths, args.file_types, include_hidden=args.include_hidden, args=args
|
||||||
)
|
)
|
||||||
if not all_texts:
|
if not all_texts:
|
||||||
print("No documents found")
|
print("No documents found")
|
||||||
@@ -1062,10 +1366,20 @@ Examples:
|
|||||||
|
|
||||||
print(f"Building index '{index_name}' with {args.backend} backend...")
|
print(f"Building index '{index_name}' with {args.backend} backend...")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
builder = LeannBuilder(
|
builder = LeannBuilder(
|
||||||
backend_name=args.backend,
|
backend_name=args.backend,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
embedding_mode=args.embedding_mode,
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
graph_degree=args.graph_degree,
|
graph_degree=args.graph_degree,
|
||||||
complexity=args.complexity,
|
complexity=args.complexity,
|
||||||
is_compact=args.compact,
|
is_compact=args.compact,
|
||||||
@@ -1085,13 +1399,101 @@ Examples:
|
|||||||
async def search_documents(self, args):
|
async def search_documents(self, args):
|
||||||
index_name = args.index_name
|
index_name = args.index_name
|
||||||
query = args.query
|
query = args.query
|
||||||
index_path = self.get_index_path(index_name)
|
|
||||||
|
|
||||||
if not self.index_exists(index_name):
|
# First try to find the index in current project
|
||||||
|
index_path = self.get_index_path(index_name)
|
||||||
|
if self.index_exists(index_name):
|
||||||
|
# Found in current project, use it
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Search across all registered projects (like list_indexes does)
|
||||||
|
all_matches = self._find_all_matching_indexes(index_name)
|
||||||
|
if not all_matches:
|
||||||
print(
|
print(
|
||||||
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
|
f"Index '{index_name}' not found. Use 'leann build {index_name} --docs <dir> [<dir2> ...]' to create it."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
elif len(all_matches) == 1:
|
||||||
|
# Found exactly one match, use it
|
||||||
|
match = all_matches[0]
|
||||||
|
if match["kind"] == "cli":
|
||||||
|
index_path = str(match["index_dir"] / "documents.leann")
|
||||||
|
else:
|
||||||
|
# App format: use the meta file to construct the path
|
||||||
|
meta_file = match["meta_file"]
|
||||||
|
file_base = match["file_base"]
|
||||||
|
index_path = str(meta_file.parent / f"{file_base}.leann")
|
||||||
|
|
||||||
|
project_info = (
|
||||||
|
"current project"
|
||||||
|
if match["is_current"]
|
||||||
|
else f"project '{match['project_path'].name}'"
|
||||||
|
)
|
||||||
|
print(f"Using index '{index_name}' from {project_info}")
|
||||||
|
else:
|
||||||
|
# Multiple matches found
|
||||||
|
if args.non_interactive:
|
||||||
|
# Non-interactive mode: automatically select the best match
|
||||||
|
# Priority: current project first, then first available
|
||||||
|
current_matches = [m for m in all_matches if m["is_current"]]
|
||||||
|
if current_matches:
|
||||||
|
match = current_matches[0]
|
||||||
|
location_desc = "current project"
|
||||||
|
else:
|
||||||
|
match = all_matches[0]
|
||||||
|
location_desc = f"project '{match['project_path'].name}'"
|
||||||
|
|
||||||
|
if match["kind"] == "cli":
|
||||||
|
index_path = str(match["index_dir"] / "documents.leann")
|
||||||
|
else:
|
||||||
|
meta_file = match["meta_file"]
|
||||||
|
file_base = match["file_base"]
|
||||||
|
index_path = str(meta_file.parent / f"{file_base}.leann")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Found {len(all_matches)} indexes named '{index_name}', using index from {location_desc}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Interactive mode: ask user to choose
|
||||||
|
print(f"Found {len(all_matches)} indexes named '{index_name}':")
|
||||||
|
for i, match in enumerate(all_matches, 1):
|
||||||
|
project_path = match["project_path"]
|
||||||
|
is_current = match["is_current"]
|
||||||
|
kind = match.get("kind", "cli")
|
||||||
|
|
||||||
|
if is_current:
|
||||||
|
print(
|
||||||
|
f" {i}. 🏠 Current project ({'CLI' if kind == 'cli' else 'APP'})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f" {i}. 📂 {project_path.name} ({'CLI' if kind == 'cli' else 'APP'})"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
choice = input(f"Which index to search? (1-{len(all_matches)}): ").strip()
|
||||||
|
choice_idx = int(choice) - 1
|
||||||
|
if 0 <= choice_idx < len(all_matches):
|
||||||
|
match = all_matches[choice_idx]
|
||||||
|
if match["kind"] == "cli":
|
||||||
|
index_path = str(match["index_dir"] / "documents.leann")
|
||||||
|
else:
|
||||||
|
meta_file = match["meta_file"]
|
||||||
|
file_base = match["file_base"]
|
||||||
|
index_path = str(meta_file.parent / f"{file_base}.leann")
|
||||||
|
|
||||||
|
project_info = (
|
||||||
|
"current project"
|
||||||
|
if match["is_current"]
|
||||||
|
else f"project '{match['project_path'].name}'"
|
||||||
|
)
|
||||||
|
print(f"Using index '{index_name}' from {project_info}")
|
||||||
|
else:
|
||||||
|
print("Invalid choice. Aborting search.")
|
||||||
|
return
|
||||||
|
except (ValueError, KeyboardInterrupt):
|
||||||
|
print("Invalid input. Aborting search.")
|
||||||
|
return
|
||||||
|
|
||||||
searcher = LeannSearcher(index_path=index_path)
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
results = searcher.search(
|
results = searcher.search(
|
||||||
@@ -1125,11 +1527,38 @@ Examples:
|
|||||||
|
|
||||||
llm_config = {"type": args.llm, "model": args.model}
|
llm_config = {"type": args.llm, "model": args.model}
|
||||||
if args.llm == "ollama":
|
if args.llm == "ollama":
|
||||||
llm_config["host"] = args.host
|
llm_config["host"] = resolve_ollama_host(args.host)
|
||||||
|
elif args.llm == "openai":
|
||||||
|
llm_config["base_url"] = resolve_openai_base_url(args.api_base)
|
||||||
|
resolved_api_key = resolve_openai_api_key(args.api_key)
|
||||||
|
if resolved_api_key:
|
||||||
|
llm_config["api_key"] = resolved_api_key
|
||||||
|
|
||||||
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
chat = LeannChat(index_path=index_path, llm_config=llm_config)
|
||||||
|
|
||||||
|
llm_kwargs: dict[str, Any] = {}
|
||||||
|
if args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
def _ask_once(prompt: str) -> None:
|
||||||
|
response = chat.ask(
|
||||||
|
prompt,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=args.prune_ratio,
|
||||||
|
recompute_embeddings=args.recompute_embeddings,
|
||||||
|
pruning_strategy=args.pruning_strategy,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
)
|
||||||
|
print(f"LEANN: {response}")
|
||||||
|
|
||||||
|
initial_query = (args.query or "").strip()
|
||||||
|
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
|
if initial_query:
|
||||||
|
_ask_once(initial_query)
|
||||||
|
|
||||||
print("LEANN Assistant ready! Type 'quit' to exit")
|
print("LEANN Assistant ready! Type 'quit' to exit")
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
|
|
||||||
@@ -1142,41 +1571,14 @@ Examples:
|
|||||||
if not user_input:
|
if not user_input:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
_ask_once(user_input)
|
||||||
llm_kwargs = {}
|
|
||||||
if args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
|
||||||
user_input,
|
|
||||||
top_k=args.top_k,
|
|
||||||
complexity=args.complexity,
|
|
||||||
beam_width=args.beam_width,
|
|
||||||
prune_ratio=args.prune_ratio,
|
|
||||||
recompute_embeddings=args.recompute_embeddings,
|
|
||||||
pruning_strategy=args.pruning_strategy,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
|
||||||
print(f"LEANN: {response}")
|
|
||||||
else:
|
else:
|
||||||
query = input("Enter your question: ").strip()
|
query = initial_query or input("Enter your question: ").strip()
|
||||||
if query:
|
if not query:
|
||||||
# Prepare LLM kwargs with thinking budget if specified
|
print("No question provided. Exiting.")
|
||||||
llm_kwargs = {}
|
return
|
||||||
if args.thinking_budget:
|
|
||||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
|
||||||
|
|
||||||
response = chat.ask(
|
_ask_once(query)
|
||||||
query,
|
|
||||||
top_k=args.top_k,
|
|
||||||
complexity=args.complexity,
|
|
||||||
beam_width=args.beam_width,
|
|
||||||
prune_ratio=args.prune_ratio,
|
|
||||||
recompute_embeddings=args.recompute_embeddings,
|
|
||||||
pruning_strategy=args.pruning_strategy,
|
|
||||||
llm_kwargs=llm_kwargs,
|
|
||||||
)
|
|
||||||
print(f"LEANN: {response}")
|
|
||||||
|
|
||||||
async def run(self, args=None):
|
async def run(self, args=None):
|
||||||
parser = self.create_parser()
|
parser = self.create_parser()
|
||||||
|
|||||||
@@ -6,11 +6,14 @@ Preserves all optimization parameters to ensure performance
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from .settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
# Set up logger with proper level
|
# Set up logger with proper level
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
LOG_LEVEL = os.getenv("LEANN_LOG_LEVEL", "WARNING").upper()
|
||||||
@@ -28,6 +31,9 @@ def compute_embeddings(
|
|||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
|
provider_options: Optional[dict[str, Any]] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Unified embedding computation entry point
|
Unified embedding computation entry point
|
||||||
@@ -43,6 +49,8 @@ def compute_embeddings(
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
"""
|
"""
|
||||||
|
provider_options = provider_options or {}
|
||||||
|
|
||||||
if mode == "sentence-transformers":
|
if mode == "sentence-transformers":
|
||||||
return compute_embeddings_sentence_transformers(
|
return compute_embeddings_sentence_transformers(
|
||||||
texts,
|
texts,
|
||||||
@@ -50,13 +58,25 @@ def compute_embeddings(
|
|||||||
is_build=is_build,
|
is_build=is_build,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
adaptive_optimization=adaptive_optimization,
|
adaptive_optimization=adaptive_optimization,
|
||||||
|
manual_tokenize=manual_tokenize,
|
||||||
|
max_length=max_length,
|
||||||
)
|
)
|
||||||
elif mode == "openai":
|
elif mode == "openai":
|
||||||
return compute_embeddings_openai(texts, model_name)
|
return compute_embeddings_openai(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
base_url=provider_options.get("base_url"),
|
||||||
|
api_key=provider_options.get("api_key"),
|
||||||
|
)
|
||||||
elif mode == "mlx":
|
elif mode == "mlx":
|
||||||
return compute_embeddings_mlx(texts, model_name)
|
return compute_embeddings_mlx(texts, model_name)
|
||||||
elif mode == "ollama":
|
elif mode == "ollama":
|
||||||
return compute_embeddings_ollama(texts, model_name, is_build=is_build)
|
return compute_embeddings_ollama(
|
||||||
|
texts,
|
||||||
|
model_name,
|
||||||
|
is_build=is_build,
|
||||||
|
host=provider_options.get("host"),
|
||||||
|
)
|
||||||
elif mode == "gemini":
|
elif mode == "gemini":
|
||||||
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
return compute_embeddings_gemini(texts, model_name, is_build=is_build)
|
||||||
else:
|
else:
|
||||||
@@ -71,6 +91,8 @@ def compute_embeddings_sentence_transformers(
|
|||||||
batch_size: int = 32,
|
batch_size: int = 32,
|
||||||
is_build: bool = False,
|
is_build: bool = False,
|
||||||
adaptive_optimization: bool = True,
|
adaptive_optimization: bool = True,
|
||||||
|
manual_tokenize: bool = False,
|
||||||
|
max_length: int = 512,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
Compute embeddings using SentenceTransformer with model caching and adaptive optimization
|
||||||
@@ -214,9 +236,13 @@ def compute_embeddings_sentence_transformers(
|
|||||||
logger.info(f"Model cached: {cache_key}")
|
logger.info(f"Model cached: {cache_key}")
|
||||||
|
|
||||||
# Compute embeddings with optimized inference mode
|
# Compute embeddings with optimized inference mode
|
||||||
logger.info(f"Starting embedding computation... (batch_size: {batch_size})")
|
logger.info(
|
||||||
|
f"Starting embedding computation... (batch_size: {batch_size}, manual_tokenize={manual_tokenize})"
|
||||||
|
)
|
||||||
|
|
||||||
# Use torch.inference_mode for optimal performance
|
start_time = time.time()
|
||||||
|
if not manual_tokenize:
|
||||||
|
# Use SentenceTransformer's optimized encode path (default)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
embeddings = model.encode(
|
embeddings = model.encode(
|
||||||
texts,
|
texts,
|
||||||
@@ -226,8 +252,114 @@ def compute_embeddings_sentence_transformers(
|
|||||||
normalize_embeddings=False,
|
normalize_embeddings=False,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
# Synchronize if CUDA to measure accurate wall time
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Manual tokenization + forward pass using HF AutoTokenizer/AutoModel
|
||||||
|
try:
|
||||||
|
from transformers import AutoModel, AutoTokenizer # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
raise ImportError(f"transformers is required for manual_tokenize=True: {e}")
|
||||||
|
|
||||||
|
# Cache tokenizer and model
|
||||||
|
tok_cache_key = f"hf_tokenizer_{model_name}"
|
||||||
|
mdl_cache_key = f"hf_model_{model_name}_{device}_{use_fp16}"
|
||||||
|
if tok_cache_key in _model_cache and mdl_cache_key in _model_cache:
|
||||||
|
hf_tokenizer = _model_cache[tok_cache_key]
|
||||||
|
hf_model = _model_cache[mdl_cache_key]
|
||||||
|
logger.info("Using cached HF tokenizer/model for manual path")
|
||||||
|
else:
|
||||||
|
logger.info("Loading HF tokenizer/model for manual tokenization path")
|
||||||
|
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
||||||
|
torch_dtype = torch.float16 if (use_fp16 and device == "cuda") else torch.float32
|
||||||
|
hf_model = AutoModel.from_pretrained(model_name, torch_dtype=torch_dtype)
|
||||||
|
hf_model.to(device)
|
||||||
|
hf_model.eval()
|
||||||
|
# Optional compile on supported devices
|
||||||
|
if device in ["cuda", "mps"]:
|
||||||
|
try:
|
||||||
|
hf_model = torch.compile(hf_model, mode="reduce-overhead", dynamic=True) # type: ignore
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
_model_cache[tok_cache_key] = hf_tokenizer
|
||||||
|
_model_cache[mdl_cache_key] = hf_model
|
||||||
|
|
||||||
|
all_embeddings: list[np.ndarray] = []
|
||||||
|
# Progress bar when building or for large inputs
|
||||||
|
show_progress = is_build or len(texts) > 32
|
||||||
|
try:
|
||||||
|
if show_progress:
|
||||||
|
from tqdm import tqdm # type: ignore
|
||||||
|
|
||||||
|
batch_iter = tqdm(
|
||||||
|
range(0, len(texts), batch_size),
|
||||||
|
desc="Embedding (manual)",
|
||||||
|
unit="batch",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
except Exception:
|
||||||
|
batch_iter = range(0, len(texts), batch_size)
|
||||||
|
|
||||||
|
start_time_manual = time.time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
for start_index in batch_iter:
|
||||||
|
end_index = min(start_index + batch_size, len(texts))
|
||||||
|
batch_texts = texts[start_index:end_index]
|
||||||
|
tokenize_start_time = time.time()
|
||||||
|
inputs = hf_tokenizer(
|
||||||
|
batch_texts,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
tokenize_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"Tokenize time taken: {tokenize_end_time - tokenize_start_time} seconds"
|
||||||
|
)
|
||||||
|
# Print shapes of all input tensors for debugging
|
||||||
|
for k, v in inputs.items():
|
||||||
|
print(f"inputs[{k!r}] shape: {getattr(v, 'shape', type(v))}")
|
||||||
|
to_device_start_time = time.time()
|
||||||
|
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||||
|
to_device_end_time = time.time()
|
||||||
|
logger.info(
|
||||||
|
f"To device time taken: {to_device_end_time - to_device_start_time} seconds"
|
||||||
|
)
|
||||||
|
forward_start_time = time.time()
|
||||||
|
outputs = hf_model(**inputs)
|
||||||
|
forward_end_time = time.time()
|
||||||
|
logger.info(f"Forward time taken: {forward_end_time - forward_start_time} seconds")
|
||||||
|
last_hidden_state = outputs.last_hidden_state # (B, L, H)
|
||||||
|
attention_mask = inputs.get("attention_mask")
|
||||||
|
if attention_mask is None:
|
||||||
|
# Fallback: assume all tokens are valid
|
||||||
|
pooled = last_hidden_state.mean(dim=1)
|
||||||
|
else:
|
||||||
|
mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
|
||||||
|
masked = last_hidden_state * mask
|
||||||
|
lengths = mask.sum(dim=1).clamp(min=1)
|
||||||
|
pooled = masked.sum(dim=1) / lengths
|
||||||
|
# Move to CPU float32
|
||||||
|
batch_embeddings = pooled.detach().to("cpu").float().numpy()
|
||||||
|
all_embeddings.append(batch_embeddings)
|
||||||
|
|
||||||
|
embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
end_time = time.time()
|
||||||
|
logger.info(f"Manual tokenize time taken: {end_time - start_time_manual} seconds")
|
||||||
|
end_time = time.time()
|
||||||
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
logger.info(f"Generated {len(embeddings)} embeddings, dimension: {embeddings.shape[1]}")
|
||||||
|
logger.info(f"Time taken: {end_time - start_time} seconds")
|
||||||
|
|
||||||
# Validate results
|
# Validate results
|
||||||
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
if np.isnan(embeddings).any() or np.isinf(embeddings).any():
|
||||||
@@ -236,12 +368,15 @@ def compute_embeddings_sentence_transformers(
|
|||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
def compute_embeddings_openai(
|
||||||
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
base_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
# TODO: @yichuan-w add progress bar only in build mode
|
# TODO: @yichuan-w add progress bar only in build mode
|
||||||
"""Compute embeddings using OpenAI API"""
|
"""Compute embeddings using OpenAI API"""
|
||||||
try:
|
try:
|
||||||
import os
|
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(f"OpenAI package not installed: {e}")
|
raise ImportError(f"OpenAI package not installed: {e}")
|
||||||
@@ -256,16 +391,18 @@ def compute_embeddings_openai(texts: list[str], model_name: str) -> np.ndarray:
|
|||||||
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI."
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
resolved_base_url = resolve_openai_base_url(base_url)
|
||||||
if not api_key:
|
resolved_api_key = resolve_openai_api_key(api_key)
|
||||||
|
|
||||||
|
if not resolved_api_key:
|
||||||
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
raise RuntimeError("OPENAI_API_KEY environment variable not set")
|
||||||
|
|
||||||
# Cache OpenAI client
|
# Cache OpenAI client
|
||||||
cache_key = "openai_client"
|
cache_key = f"openai_client::{resolved_base_url}"
|
||||||
if cache_key in _model_cache:
|
if cache_key in _model_cache:
|
||||||
client = _model_cache[cache_key]
|
client = _model_cache[cache_key]
|
||||||
else:
|
else:
|
||||||
client = openai.OpenAI(api_key=api_key)
|
client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url)
|
||||||
_model_cache[cache_key] = client
|
_model_cache[cache_key] = client
|
||||||
logger.info("OpenAI client cached")
|
logger.info("OpenAI client cached")
|
||||||
|
|
||||||
@@ -390,7 +527,10 @@ def compute_embeddings_mlx(chunks: list[str], model_name: str, batch_size: int =
|
|||||||
|
|
||||||
|
|
||||||
def compute_embeddings_ollama(
|
def compute_embeddings_ollama(
|
||||||
texts: list[str], model_name: str, is_build: bool = False, host: str = "http://localhost:11434"
|
texts: list[str],
|
||||||
|
model_name: str,
|
||||||
|
is_build: bool = False,
|
||||||
|
host: Optional[str] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute embeddings using Ollama API with simplified batch processing.
|
Compute embeddings using Ollama API with simplified batch processing.
|
||||||
@@ -401,7 +541,7 @@ def compute_embeddings_ollama(
|
|||||||
texts: List of texts to compute embeddings for
|
texts: List of texts to compute embeddings for
|
||||||
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large")
|
||||||
is_build: Whether this is a build operation (shows progress bar)
|
is_build: Whether this is a build operation (shows progress bar)
|
||||||
host: Ollama host URL (default: http://localhost:11434)
|
host: Ollama host URL (defaults to environment or http://localhost:11434)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
Normalized embeddings array, shape: (len(texts), embedding_dim)
|
||||||
@@ -416,17 +556,19 @@ def compute_embeddings_ollama(
|
|||||||
if not texts:
|
if not texts:
|
||||||
raise ValueError("Cannot compute embeddings for empty text list")
|
raise ValueError("Cannot compute embeddings for empty text list")
|
||||||
|
|
||||||
|
resolved_host = resolve_ollama_host(host)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}'"
|
f"Computing embeddings for {len(texts)} texts using Ollama API, model: '{model_name}', host: '{resolved_host}'"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if Ollama is running
|
# Check if Ollama is running
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/version", timeout=5)
|
response = requests.get(f"{resolved_host}/api/version", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"❌ Could not connect to Ollama at {host}.\n\n"
|
f"❌ Could not connect to Ollama at {resolved_host}.\n\n"
|
||||||
"Please ensure Ollama is running:\n"
|
"Please ensure Ollama is running:\n"
|
||||||
" • macOS/Linux: ollama serve\n"
|
" • macOS/Linux: ollama serve\n"
|
||||||
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
" • Windows: Make sure Ollama is running in the system tray\n\n"
|
||||||
@@ -438,7 +580,7 @@ def compute_embeddings_ollama(
|
|||||||
|
|
||||||
# Check if model exists and provide helpful suggestions
|
# Check if model exists and provide helpful suggestions
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{host}/api/tags", timeout=5)
|
response = requests.get(f"{resolved_host}/api/tags", timeout=5)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
models = response.json()
|
models = response.json()
|
||||||
model_names = [model["name"] for model in models.get("models", [])]
|
model_names = [model["name"] for model in models.get("models", [])]
|
||||||
@@ -501,7 +643,9 @@ def compute_embeddings_ollama(
|
|||||||
# Verify the model supports embeddings by testing it
|
# Verify the model supports embeddings by testing it
|
||||||
try:
|
try:
|
||||||
test_response = requests.post(
|
test_response = requests.post(
|
||||||
f"{host}/api/embeddings", json={"model": model_name, "prompt": "test"}, timeout=10
|
f"{resolved_host}/api/embeddings",
|
||||||
|
json={"model": model_name, "prompt": "test"},
|
||||||
|
timeout=10,
|
||||||
)
|
)
|
||||||
if test_response.status_code != 200:
|
if test_response.status_code != 200:
|
||||||
error_msg = (
|
error_msg = (
|
||||||
@@ -548,7 +692,7 @@ def compute_embeddings_ollama(
|
|||||||
while retry_count < max_retries:
|
while retry_count < max_retries:
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{host}/api/embeddings",
|
f"{resolved_host}/api/embeddings",
|
||||||
json={"model": model_name, "prompt": truncated_text},
|
json={"model": model_name, "prompt": truncated_text},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from .settings import encode_provider_options
|
||||||
|
|
||||||
# Lightweight, self-contained server manager with no cross-process inspection
|
# Lightweight, self-contained server manager with no cross-process inspection
|
||||||
|
|
||||||
# Set up logging based on environment variable
|
# Set up logging based on environment variable
|
||||||
@@ -82,16 +84,40 @@ class EmbeddingServerManager:
|
|||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start the embedding server."""
|
"""Start the embedding server."""
|
||||||
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
# passages_file may be present in kwargs for server CLI, but we don't need it here
|
||||||
|
provider_options = kwargs.pop("provider_options", None)
|
||||||
|
|
||||||
|
config_signature = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
|
||||||
# If this manager already has a live server, just reuse it
|
# If this manager already has a live server, just reuse it
|
||||||
if self.server_process and self.server_process.poll() is None and self.server_port:
|
if (
|
||||||
|
self.server_process
|
||||||
|
and self.server_process.poll() is None
|
||||||
|
and self.server_port
|
||||||
|
and self._server_config == config_signature
|
||||||
|
):
|
||||||
logger.info("Reusing in-process server")
|
logger.info("Reusing in-process server")
|
||||||
return True, self.server_port
|
return True, self.server_port
|
||||||
|
|
||||||
|
# Configuration changed, stop existing server before starting a new one
|
||||||
|
if self.server_process and self.server_process.poll() is None:
|
||||||
|
logger.info("Existing server configuration differs; restarting embedding server")
|
||||||
|
self.stop_server()
|
||||||
|
|
||||||
# For Colab environment, use a different strategy
|
# For Colab environment, use a different strategy
|
||||||
if _is_colab_environment():
|
if _is_colab_environment():
|
||||||
logger.info("Detected Colab environment, using alternative startup strategy")
|
logger.info("Detected Colab environment, using alternative startup strategy")
|
||||||
return self._start_server_colab(port, model_name, embedding_mode, **kwargs)
|
return self._start_server_colab(
|
||||||
|
port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Always pick a fresh available port
|
# Always pick a fresh available port
|
||||||
try:
|
try:
|
||||||
@@ -101,13 +127,21 @@ class EmbeddingServerManager:
|
|||||||
return False, port
|
return False, port
|
||||||
|
|
||||||
# Start a new server
|
# Start a new server
|
||||||
return self._start_new_server(actual_port, model_name, embedding_mode, **kwargs)
|
return self._start_new_server(
|
||||||
|
actual_port,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=provider_options,
|
||||||
|
config_signature=config_signature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def _start_server_colab(
|
def _start_server_colab(
|
||||||
self,
|
self,
|
||||||
port: int,
|
port: int,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
embedding_mode: str = "sentence-transformers",
|
embedding_mode: str = "sentence-transformers",
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start server with Colab-specific configuration."""
|
"""Start server with Colab-specific configuration."""
|
||||||
@@ -125,8 +159,20 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# In Colab, we'll use a more direct approach
|
# In Colab, we'll use a more direct approach
|
||||||
self._launch_server_process_colab(command, actual_port)
|
self._launch_server_process_colab(
|
||||||
return self._wait_for_server_ready_colab(actual_port)
|
command,
|
||||||
|
actual_port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready_colab(actual_port)
|
||||||
|
if started:
|
||||||
|
self._server_config = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server in Colab: {e}")
|
logger.error(f"Failed to start embedding server in Colab: {e}")
|
||||||
return False, actual_port
|
return False, actual_port
|
||||||
@@ -134,7 +180,13 @@ class EmbeddingServerManager:
|
|||||||
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
# Note: No compatibility check needed; manager is per-searcher and configs are stable per instance
|
||||||
|
|
||||||
def _start_new_server(
|
def _start_new_server(
|
||||||
self, port: int, model_name: str, embedding_mode: str, **kwargs
|
self,
|
||||||
|
port: int,
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
config_signature: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Start a new embedding server on the given port."""
|
"""Start a new embedding server on the given port."""
|
||||||
logger.info(f"Starting embedding server on port {port}...")
|
logger.info(f"Starting embedding server on port {port}...")
|
||||||
@@ -142,8 +194,20 @@ class EmbeddingServerManager:
|
|||||||
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
command = self._build_server_command(port, model_name, embedding_mode, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._launch_server_process(command, port)
|
self._launch_server_process(
|
||||||
return self._wait_for_server_ready(port)
|
command,
|
||||||
|
port,
|
||||||
|
provider_options=provider_options,
|
||||||
|
)
|
||||||
|
started, ready_port = self._wait_for_server_ready(port)
|
||||||
|
if started:
|
||||||
|
self._server_config = config_signature or {
|
||||||
|
"model_name": model_name,
|
||||||
|
"passages_file": kwargs.get("passages_file", ""),
|
||||||
|
"embedding_mode": embedding_mode,
|
||||||
|
"provider_options": provider_options or {},
|
||||||
|
}
|
||||||
|
return started, ready_port
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start embedding server: {e}")
|
logger.error(f"Failed to start embedding server: {e}")
|
||||||
return False, port
|
return False, port
|
||||||
@@ -173,7 +237,12 @@ class EmbeddingServerManager:
|
|||||||
|
|
||||||
return command
|
return command
|
||||||
|
|
||||||
def _launch_server_process(self, command: list, port: int) -> None:
|
def _launch_server_process(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process."""
|
"""Launch the server process."""
|
||||||
project_root = Path(__file__).parent.parent.parent.parent.parent
|
project_root = Path(__file__).parent.parent.parent.parent.parent
|
||||||
logger.info(f"Command: {' '.join(command)}")
|
logger.info(f"Command: {' '.join(command)}")
|
||||||
@@ -192,14 +261,21 @@ class EmbeddingServerManager:
|
|||||||
stderr_target = None # Direct to console for visible logs
|
stderr_target = None # Direct to console for visible logs
|
||||||
|
|
||||||
# Start embedding server subprocess
|
# Start embedding server subprocess
|
||||||
|
logger.info(f"Starting server process with command: {' '.join(command)}")
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
cwd=project_root,
|
cwd=project_root,
|
||||||
stdout=stdout_target,
|
stdout=stdout_target,
|
||||||
stderr=stderr_target,
|
stderr=stderr_target,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
# Record config for in-process reuse
|
# Record config for in-process reuse (best effort; refined later when ready)
|
||||||
try:
|
try:
|
||||||
self._server_config = {
|
self._server_config = {
|
||||||
"model_name": command[command.index("--model-name") + 1]
|
"model_name": command[command.index("--model-name") + 1]
|
||||||
@@ -211,12 +287,14 @@ class EmbeddingServerManager:
|
|||||||
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
"embedding_mode": command[command.index("--embedding-mode") + 1]
|
||||||
if "--embedding-mode" in command
|
if "--embedding-mode" in command
|
||||||
else "sentence-transformers",
|
else "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
self._server_config = {
|
self._server_config = {
|
||||||
"model_name": "",
|
"model_name": "",
|
||||||
"passages_file": "",
|
"passages_file": "",
|
||||||
"embedding_mode": "sentence-transformers",
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
logger.info(f"Server process started with PID: {self.server_process.pid}")
|
||||||
|
|
||||||
@@ -321,16 +399,27 @@ class EmbeddingServerManager:
|
|||||||
# Removed: cross-process adoption no longer supported
|
# Removed: cross-process adoption no longer supported
|
||||||
return
|
return
|
||||||
|
|
||||||
def _launch_server_process_colab(self, command: list, port: int) -> None:
|
def _launch_server_process_colab(
|
||||||
|
self,
|
||||||
|
command: list,
|
||||||
|
port: int,
|
||||||
|
provider_options: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
"""Launch the server process with Colab-specific settings."""
|
"""Launch the server process with Colab-specific settings."""
|
||||||
logger.info(f"Colab Command: {' '.join(command)}")
|
logger.info(f"Colab Command: {' '.join(command)}")
|
||||||
|
|
||||||
# In Colab, we need to be more careful about process management
|
# In Colab, we need to be more careful about process management
|
||||||
|
env = os.environ.copy()
|
||||||
|
encoded_options = encode_provider_options(provider_options)
|
||||||
|
if encoded_options:
|
||||||
|
env["LEANN_EMBEDDING_OPTIONS"] = encoded_options
|
||||||
|
|
||||||
self.server_process = subprocess.Popen(
|
self.server_process = subprocess.Popen(
|
||||||
command,
|
command,
|
||||||
stdout=subprocess.PIPE,
|
stdout=subprocess.PIPE,
|
||||||
stderr=subprocess.PIPE,
|
stderr=subprocess.PIPE,
|
||||||
text=True,
|
text=True,
|
||||||
|
env=env,
|
||||||
)
|
)
|
||||||
self.server_port = port
|
self.server_port = port
|
||||||
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
logger.info(f"Colab server process started with PID: {self.server_process.pid}")
|
||||||
@@ -344,6 +433,7 @@ class EmbeddingServerManager:
|
|||||||
"model_name": "",
|
"model_name": "",
|
||||||
"passages_file": "",
|
"passages_file": "",
|
||||||
"embedding_mode": "sentence-transformers",
|
"embedding_mode": "sentence-transformers",
|
||||||
|
"provider_options": provider_options or {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
def _wait_for_server_ready_colab(self, port: int) -> tuple[bool, int]:
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ def handle_request(request):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build simplified command
|
# Build simplified command with non-interactive flag for MCP compatibility
|
||||||
cmd = [
|
cmd = [
|
||||||
"leann",
|
"leann",
|
||||||
"search",
|
"search",
|
||||||
@@ -102,6 +102,7 @@ def handle_request(request):
|
|||||||
args["query"],
|
args["query"],
|
||||||
f"--top-k={args.get('top_k', 5)}",
|
f"--top-k={args.get('top_k', 5)}",
|
||||||
f"--complexity={args.get('complexity', 32)}",
|
f"--complexity={args.get('complexity', 32)}",
|
||||||
|
"--non-interactive",
|
||||||
]
|
]
|
||||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
|||||||
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
240
packages/leann-core/src/leann/metadata_filter.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""
|
||||||
|
Metadata filtering engine for LEANN search results.
|
||||||
|
|
||||||
|
This module provides generic metadata filtering capabilities that can be applied
|
||||||
|
to search results from any LEANN backend. The filtering supports various
|
||||||
|
operators for different data types including numbers, strings, booleans, and lists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type alias for filter specifications
|
||||||
|
FilterValue = Union[str, int, float, bool, list]
|
||||||
|
FilterSpec = dict[str, FilterValue]
|
||||||
|
MetadataFilters = dict[str, FilterSpec]
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataFilterEngine:
|
||||||
|
"""
|
||||||
|
Engine for evaluating metadata filters against search results.
|
||||||
|
|
||||||
|
Supports various operators for filtering based on metadata fields:
|
||||||
|
- Comparison: ==, !=, <, <=, >, >=
|
||||||
|
- Membership: in, not_in
|
||||||
|
- String operations: contains, starts_with, ends_with
|
||||||
|
- Boolean operations: is_true, is_false
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the filter engine with supported operators."""
|
||||||
|
self.operators = {
|
||||||
|
"==": self._equals,
|
||||||
|
"!=": self._not_equals,
|
||||||
|
"<": self._less_than,
|
||||||
|
"<=": self._less_than_or_equal,
|
||||||
|
">": self._greater_than,
|
||||||
|
">=": self._greater_than_or_equal,
|
||||||
|
"in": self._in,
|
||||||
|
"not_in": self._not_in,
|
||||||
|
"contains": self._contains,
|
||||||
|
"starts_with": self._starts_with,
|
||||||
|
"ends_with": self._ends_with,
|
||||||
|
"is_true": self._is_true,
|
||||||
|
"is_false": self._is_false,
|
||||||
|
}
|
||||||
|
|
||||||
|
def apply_filters(
|
||||||
|
self, search_results: list[dict[str, Any]], metadata_filters: MetadataFilters
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Apply metadata filters to a list of search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_results: List of result dictionaries, each containing 'metadata' field
|
||||||
|
metadata_filters: Dictionary of filter specifications
|
||||||
|
Format: {"field_name": {"operator": value}}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list of search results
|
||||||
|
"""
|
||||||
|
if not metadata_filters:
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
logger.debug(f"Applying filters: {metadata_filters}")
|
||||||
|
logger.debug(f"Input results count: {len(search_results)}")
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for result in search_results:
|
||||||
|
if self._evaluate_filters(result, metadata_filters):
|
||||||
|
filtered_results.append(result)
|
||||||
|
|
||||||
|
logger.debug(f"Filtered results count: {len(filtered_results)}")
|
||||||
|
return filtered_results
|
||||||
|
|
||||||
|
def _evaluate_filters(self, result: dict[str, Any], filters: MetadataFilters) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate all filters against a single search result.
|
||||||
|
|
||||||
|
All filters must pass (AND logic) for the result to be included.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary (including metadata, text, etc.)
|
||||||
|
filters: Filter specifications to evaluate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if all filters pass, False otherwise
|
||||||
|
"""
|
||||||
|
for field_name, filter_spec in filters.items():
|
||||||
|
if not self._evaluate_field_filter(result, field_name, filter_spec):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _evaluate_field_filter(
|
||||||
|
self, result: dict[str, Any], field_name: str, filter_spec: FilterSpec
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate a single field filter against a search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Full search result dictionary
|
||||||
|
field_name: Name of the field to filter on
|
||||||
|
filter_spec: Filter specification for this field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the filter passes, False otherwise
|
||||||
|
"""
|
||||||
|
# First check top-level fields, then check metadata
|
||||||
|
field_value = result.get(field_name)
|
||||||
|
if field_value is None:
|
||||||
|
# Try to get from metadata if not found at top level
|
||||||
|
metadata = result.get("metadata", {})
|
||||||
|
field_value = metadata.get(field_name)
|
||||||
|
|
||||||
|
# Handle missing fields - they fail all filters except existence checks
|
||||||
|
if field_value is None:
|
||||||
|
logger.debug(f"Field '{field_name}' not found in result or metadata")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Evaluate each operator in the filter spec
|
||||||
|
for operator, expected_value in filter_spec.items():
|
||||||
|
if operator not in self.operators:
|
||||||
|
logger.warning(f"Unsupported operator: {operator}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.operators[operator](field_value, expected_value):
|
||||||
|
logger.debug(
|
||||||
|
f"Filter failed: {field_name} {operator} {expected_value} "
|
||||||
|
f"(actual: {field_value})"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Error evaluating filter {field_name} {operator} {expected_value}: {e}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Comparison operators
|
||||||
|
def _equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value equals expected value."""
|
||||||
|
return field_value == expected_value
|
||||||
|
|
||||||
|
def _not_equals(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value does not equal expected value."""
|
||||||
|
return field_value != expected_value
|
||||||
|
|
||||||
|
def _less_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a < b)
|
||||||
|
|
||||||
|
def _less_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is less than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a <= b)
|
||||||
|
|
||||||
|
def _greater_than(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a > b)
|
||||||
|
|
||||||
|
def _greater_than_or_equal(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is greater than or equal to expected value."""
|
||||||
|
return self._numeric_compare(field_value, expected_value, lambda a, b: a >= b)
|
||||||
|
|
||||||
|
# Membership operators
|
||||||
|
def _in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'in' operator requires a list, tuple, or set")
|
||||||
|
return field_value in expected_value
|
||||||
|
|
||||||
|
def _not_in(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is not in the expected list/collection."""
|
||||||
|
if not isinstance(expected_value, (list, tuple, set)):
|
||||||
|
raise ValueError("'not_in' operator requires a list, tuple, or set")
|
||||||
|
return field_value not in expected_value
|
||||||
|
|
||||||
|
# String operators
|
||||||
|
def _contains(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value contains the expected substring."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return expected_str in field_str
|
||||||
|
|
||||||
|
def _starts_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value starts with the expected prefix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.startswith(expected_str)
|
||||||
|
|
||||||
|
def _ends_with(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value ends with the expected suffix."""
|
||||||
|
field_str = str(field_value)
|
||||||
|
expected_str = str(expected_value)
|
||||||
|
return field_str.endswith(expected_str)
|
||||||
|
|
||||||
|
# Boolean operators
|
||||||
|
def _is_true(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is truthy."""
|
||||||
|
return bool(field_value)
|
||||||
|
|
||||||
|
def _is_false(self, field_value: Any, expected_value: Any) -> bool:
|
||||||
|
"""Check if field value is falsy."""
|
||||||
|
return not bool(field_value)
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
def _numeric_compare(self, field_value: Any, expected_value: Any, compare_func) -> bool:
|
||||||
|
"""
|
||||||
|
Helper for numeric comparisons with type coercion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_value: Value from metadata
|
||||||
|
expected_value: Value to compare against
|
||||||
|
compare_func: Comparison function to apply
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result of comparison
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to convert both values to numbers for comparison
|
||||||
|
if isinstance(field_value, str) and isinstance(expected_value, str):
|
||||||
|
# String comparison if both are strings
|
||||||
|
return compare_func(field_value, expected_value)
|
||||||
|
|
||||||
|
# Numeric comparison - attempt to convert to float
|
||||||
|
field_num = (
|
||||||
|
float(field_value) if not isinstance(field_value, (int, float)) else field_value
|
||||||
|
)
|
||||||
|
expected_num = (
|
||||||
|
float(expected_value)
|
||||||
|
if not isinstance(expected_value, (int, float))
|
||||||
|
else expected_value
|
||||||
|
)
|
||||||
|
|
||||||
|
return compare_func(field_num, expected_num)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Fall back to string comparison if numeric conversion fails
|
||||||
|
return compare_func(str(field_value), str(expected_value))
|
||||||
@@ -41,6 +41,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
print("WARNING: embedding_model not found in meta.json. Recompute will fail.")
|
||||||
|
|
||||||
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
self.embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
|
self.embedding_options = self.meta.get("embedding_options", {})
|
||||||
|
|
||||||
self.embedding_server_manager = EmbeddingServerManager(
|
self.embedding_server_manager = EmbeddingServerManager(
|
||||||
backend_module_name=backend_module_name,
|
backend_module_name=backend_module_name,
|
||||||
@@ -77,6 +78,7 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
passages_file=passages_source_file,
|
passages_file=passages_source_file,
|
||||||
distance_metric=distance_metric,
|
distance_metric=distance_metric,
|
||||||
enable_warmup=kwargs.get("enable_warmup", False),
|
enable_warmup=kwargs.get("enable_warmup", False),
|
||||||
|
provider_options=self.embedding_options,
|
||||||
)
|
)
|
||||||
if not server_started:
|
if not server_started:
|
||||||
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
raise RuntimeError(f"Failed to start embedding server on port {actual_port}")
|
||||||
@@ -125,7 +127,12 @@ class BaseSearcher(LeannBackendSearcherInterface, ABC):
|
|||||||
from .embedding_compute import compute_embeddings
|
from .embedding_compute import compute_embeddings
|
||||||
|
|
||||||
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
embedding_mode = self.meta.get("embedding_mode", "sentence-transformers")
|
||||||
return compute_embeddings([query], self.embedding_model, embedding_mode)
|
return compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.embedding_model,
|
||||||
|
embedding_mode,
|
||||||
|
provider_options=self.embedding_options,
|
||||||
|
)
|
||||||
|
|
||||||
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
def _compute_embedding_via_server(self, chunks: list, zmq_port: int) -> np.ndarray:
|
||||||
"""Compute embeddings using the ZMQ embedding server."""
|
"""Compute embeddings using the ZMQ embedding server."""
|
||||||
|
|||||||
74
packages/leann-core/src/leann/settings.py
Normal file
74
packages/leann-core/src/leann/settings.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Runtime configuration helpers for LEANN."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Default fallbacks to preserve current behaviour while keeping them in one place.
|
||||||
|
_DEFAULT_OLLAMA_HOST = "http://localhost:11434"
|
||||||
|
_DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_url(value: str) -> str:
|
||||||
|
"""Normalize URL strings by stripping trailing slashes."""
|
||||||
|
|
||||||
|
return value.rstrip("/") if value else value
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_ollama_host(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the Ollama-compatible endpoint to use."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_LOCAL_LLM_HOST"),
|
||||||
|
os.getenv("LEANN_OLLAMA_HOST"),
|
||||||
|
os.getenv("OLLAMA_HOST"),
|
||||||
|
os.getenv("LOCAL_LLM_ENDPOINT"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OLLAMA_HOST)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_base_url(explicit: str | None = None) -> str:
|
||||||
|
"""Resolve the base URL for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
candidates = (
|
||||||
|
explicit,
|
||||||
|
os.getenv("LEANN_OPENAI_BASE_URL"),
|
||||||
|
os.getenv("OPENAI_BASE_URL"),
|
||||||
|
os.getenv("LOCAL_OPENAI_BASE_URL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate:
|
||||||
|
return _clean_url(candidate)
|
||||||
|
|
||||||
|
return _clean_url(_DEFAULT_OPENAI_BASE_URL)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_openai_api_key(explicit: str | None = None) -> str | None:
|
||||||
|
"""Resolve the API key for OpenAI-compatible services."""
|
||||||
|
|
||||||
|
if explicit:
|
||||||
|
return explicit
|
||||||
|
|
||||||
|
return os.getenv("OPENAI_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_provider_options(options: dict[str, Any] | None) -> str | None:
|
||||||
|
"""Serialize provider options for child processes."""
|
||||||
|
|
||||||
|
if not options:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.dumps(options)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
# Fall back to empty payload if serialization fails
|
||||||
|
return None
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
Transform your development workflow with intelligent code assistance using LEANN's semantic search directly in Claude Code.
|
||||||
|
|
||||||
|
For agent-facing discovery details, see `llms.txt` in the repository root.
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
Install LEANN globally for MCP integration (with default backend):
|
Install LEANN globally for MCP integration (with default backend):
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "leann"
|
name = "leann"
|
||||||
version = "0.3.0"
|
version = "0.3.4"
|
||||||
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
description = "LEANN - The smallest vector index in the world. RAG Everything with LEANN!"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@@ -46,6 +46,13 @@ dependencies = [
|
|||||||
"pathspec>=0.12.1",
|
"pathspec>=0.12.1",
|
||||||
"nbconvert>=7.16.6",
|
"nbconvert>=7.16.6",
|
||||||
"gitignore-parser>=0.1.12",
|
"gitignore-parser>=0.1.12",
|
||||||
|
# AST-aware code chunking dependencies
|
||||||
|
"astchunk>=0.1.0",
|
||||||
|
"tree-sitter>=0.20.0",
|
||||||
|
"tree-sitter-python>=0.20.0",
|
||||||
|
"tree-sitter-java>=0.20.0",
|
||||||
|
"tree-sitter-c-sharp>=0.20.0",
|
||||||
|
"tree-sitter-typescript>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -92,6 +99,7 @@ wechat-exporter = "wechat_exporter.main:main"
|
|||||||
leann-core = { path = "packages/leann-core", editable = true }
|
leann-core = { path = "packages/leann-core", editable = true }
|
||||||
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
leann-backend-diskann = { path = "packages/leann-backend-diskann", editable = true }
|
||||||
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
leann-backend-hnsw = { path = "packages/leann-backend-hnsw", editable = true }
|
||||||
|
astchunk = { path = "packages/astchunk-leann", editable = true }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py39"
|
target-version = "py39"
|
||||||
|
|||||||
397
tests/test_astchunk_integration.py
Normal file
397
tests/test_astchunk_integration.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""
|
||||||
|
Test suite for astchunk integration with LEANN.
|
||||||
|
Tests AST-aware chunking functionality, language detection, and fallback mechanisms.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Add apps directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent / "apps"))
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from chunking import (
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockDocument:
|
||||||
|
"""Mock LlamaIndex Document for testing."""
|
||||||
|
|
||||||
|
def __init__(self, content: str, file_path: str = "", metadata: Optional[dict] = None):
|
||||||
|
self.content = content
|
||||||
|
self.metadata = metadata or {}
|
||||||
|
if file_path:
|
||||||
|
self.metadata["file_path"] = file_path
|
||||||
|
|
||||||
|
def get_content(self) -> str:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestCodeFileDetection:
|
||||||
|
"""Test code file detection and language mapping."""
|
||||||
|
|
||||||
|
def test_detect_code_files_python(self):
|
||||||
|
"""Test detection of Python files."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("print('hello')", "/path/to/file.py"),
|
||||||
|
MockDocument("This is text", "/path/to/file.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 1
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
assert code_docs[0].metadata["language"] == "python"
|
||||||
|
assert code_docs[0].metadata["is_code"] is True
|
||||||
|
assert text_docs[0].metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_detect_code_files_multiple_languages(self):
|
||||||
|
"""Test detection of multiple programming languages."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def func():", "/path/to/script.py"),
|
||||||
|
MockDocument("public class Test {}", "/path/to/Test.java"),
|
||||||
|
MockDocument("interface ITest {}", "/path/to/test.ts"),
|
||||||
|
MockDocument("using System;", "/path/to/Program.cs"),
|
||||||
|
MockDocument("Regular text content", "/path/to/document.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 4
|
||||||
|
assert len(text_docs) == 1
|
||||||
|
|
||||||
|
languages = [doc.metadata["language"] for doc in code_docs]
|
||||||
|
assert "python" in languages
|
||||||
|
assert "java" in languages
|
||||||
|
assert "typescript" in languages
|
||||||
|
assert "csharp" in languages
|
||||||
|
|
||||||
|
def test_detect_code_files_no_file_path(self):
|
||||||
|
"""Test handling of documents without file paths."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("some content"),
|
||||||
|
MockDocument("other content", metadata={"some_key": "value"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
code_docs, text_docs = detect_code_files(docs)
|
||||||
|
|
||||||
|
assert len(code_docs) == 0
|
||||||
|
assert len(text_docs) == 2
|
||||||
|
for doc in text_docs:
|
||||||
|
assert doc.metadata["is_code"] is False
|
||||||
|
|
||||||
|
def test_get_language_from_extension(self):
|
||||||
|
"""Test language detection from file extensions."""
|
||||||
|
assert get_language_from_extension("test.py") == "python"
|
||||||
|
assert get_language_from_extension("Test.java") == "java"
|
||||||
|
assert get_language_from_extension("component.tsx") == "typescript"
|
||||||
|
assert get_language_from_extension("Program.cs") == "csharp"
|
||||||
|
assert get_language_from_extension("document.txt") is None
|
||||||
|
assert get_language_from_extension("") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestChunkingFunctions:
|
||||||
|
"""Test various chunking functionality."""
|
||||||
|
|
||||||
|
def test_create_traditional_chunks(self):
|
||||||
|
"""Test traditional text chunking."""
|
||||||
|
docs = [
|
||||||
|
MockDocument(
|
||||||
|
"This is a test document. It has multiple sentences. We want to test chunking."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_traditional_chunks(docs, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_traditional_chunks_empty_docs(self):
|
||||||
|
"""Test traditional chunking with empty documents."""
|
||||||
|
chunks = create_traditional_chunks([], chunk_size=50, chunk_overlap=10)
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip astchunk tests in CI - dependency may not be available",
|
||||||
|
)
|
||||||
|
def test_create_ast_chunks_with_astchunk_available(self):
|
||||||
|
"""Test AST chunking when astchunk is available."""
|
||||||
|
python_code = '''
|
||||||
|
def hello_world():
|
||||||
|
"""Print hello world message."""
|
||||||
|
print("Hello, World!")
|
||||||
|
|
||||||
|
def add_numbers(a, b):
|
||||||
|
"""Add two numbers and return the result."""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
class Calculator:
|
||||||
|
"""A simple calculator class."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.history = []
|
||||||
|
|
||||||
|
def add(self, a, b):
|
||||||
|
result = a + b
|
||||||
|
self.history.append(f"{a} + {b} = {result}")
|
||||||
|
return result
|
||||||
|
'''
|
||||||
|
|
||||||
|
docs = [MockDocument(python_code, "/test/calculator.py", {"language": "python"})]
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
|
||||||
|
# Should have multiple chunks due to different functions/classes
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
assert all(len(chunk.strip()) > 0 for chunk in chunks)
|
||||||
|
|
||||||
|
# Check that code structure is somewhat preserved
|
||||||
|
combined_content = " ".join(chunks)
|
||||||
|
assert "def hello_world" in combined_content
|
||||||
|
assert "class Calculator" in combined_content
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
# astchunk not available, should fall back to traditional chunking
|
||||||
|
chunks = create_ast_chunks(docs, max_chunk_size=200, chunk_overlap=50)
|
||||||
|
assert len(chunks) > 0 # Should still get chunks from fallback
|
||||||
|
|
||||||
|
def test_create_ast_chunks_fallback_to_traditional(self):
|
||||||
|
"""Test AST chunking falls back to traditional when astchunk is not available."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
# Mock astchunk import to fail
|
||||||
|
with patch("chunking.create_ast_chunks"):
|
||||||
|
# First call (actual test) should import astchunk and potentially fail
|
||||||
|
# Let's call the actual function to test the import error handling
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should return some chunks (either from astchunk or fallback)
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_text_chunks_traditional_mode(self):
|
||||||
|
"""Test text chunking in traditional mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(docs, use_ast_chunking=False, chunk_size=50, chunk_overlap=10)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_ast_mode(self):
|
||||||
|
"""Test text chunking in AST mode."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("def test(): pass", "/test/script.py"),
|
||||||
|
MockDocument("This is regular text.", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs,
|
||||||
|
use_ast_chunking=True,
|
||||||
|
ast_chunk_size=100,
|
||||||
|
ast_chunk_overlap=20,
|
||||||
|
chunk_size=50,
|
||||||
|
chunk_overlap=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) > 0
|
||||||
|
assert all(isinstance(chunk, str) for chunk in chunks)
|
||||||
|
|
||||||
|
def test_create_text_chunks_custom_extensions(self):
|
||||||
|
"""Test text chunking with custom code file extensions."""
|
||||||
|
docs = [
|
||||||
|
MockDocument("function test() {}", "/test/script.js"), # Not in default extensions
|
||||||
|
MockDocument("Regular text", "/test/doc.txt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# First without custom extensions - should treat .js as text
|
||||||
|
chunks_without = create_text_chunks(docs, use_ast_chunking=True, code_file_extensions=None)
|
||||||
|
|
||||||
|
# Then with custom extensions - should treat .js as code
|
||||||
|
chunks_with = create_text_chunks(
|
||||||
|
docs, use_ast_chunking=True, code_file_extensions=[".js", ".jsx"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Both should return chunks
|
||||||
|
assert len(chunks_without) > 0
|
||||||
|
assert len(chunks_with) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationWithDocumentRAG:
|
||||||
|
"""Integration tests with the document RAG system."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_code_dir(self):
|
||||||
|
"""Create a temporary directory with sample code files."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
temp_path = Path(temp_dir)
|
||||||
|
|
||||||
|
# Create sample Python file
|
||||||
|
python_file = temp_path / "example.py"
|
||||||
|
python_file.write_text('''
|
||||||
|
def fibonacci(n):
|
||||||
|
"""Calculate fibonacci number."""
|
||||||
|
if n <= 1:
|
||||||
|
return n
|
||||||
|
return fibonacci(n-1) + fibonacci(n-2)
|
||||||
|
|
||||||
|
class MathUtils:
|
||||||
|
@staticmethod
|
||||||
|
def factorial(n):
|
||||||
|
if n <= 1:
|
||||||
|
return 1
|
||||||
|
return n * MathUtils.factorial(n-1)
|
||||||
|
''')
|
||||||
|
|
||||||
|
# Create sample text file
|
||||||
|
text_file = temp_path / "readme.txt"
|
||||||
|
text_file.write_text("This is a sample text file for testing purposes.")
|
||||||
|
|
||||||
|
yield temp_path
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_document_rag_with_ast_chunking(self, temp_code_dir):
|
||||||
|
"""Test document RAG with AST chunking enabled."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/document_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--embedding-mode",
|
||||||
|
"sentence-transformers",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--data-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--enable-code-chunking",
|
||||||
|
"--query",
|
||||||
|
"How does the fibonacci function work?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=300, # 5 minutes
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should succeed even if astchunk is not available (fallback)
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Index saved to" in output or "Using existing index" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip integration tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_code_rag_application(self, temp_code_dir):
|
||||||
|
"""Test the specialized code RAG application."""
|
||||||
|
with tempfile.TemporaryDirectory() as index_dir:
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/code_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--index-dir",
|
||||||
|
index_dir,
|
||||||
|
"--repo-dir",
|
||||||
|
str(temp_code_dir),
|
||||||
|
"--query",
|
||||||
|
"What classes are defined in this code?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300, env=env)
|
||||||
|
|
||||||
|
# Should succeed
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Using AST-aware chunking" in output or "traditional chunking" in output
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
pytest.skip("Test timed out - likely due to model download in CI")
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test error handling and edge cases."""
|
||||||
|
|
||||||
|
def test_text_chunking_empty_documents(self):
|
||||||
|
"""Test text chunking with empty document list."""
|
||||||
|
chunks = create_text_chunks([])
|
||||||
|
assert chunks == []
|
||||||
|
|
||||||
|
def test_text_chunking_invalid_parameters(self):
|
||||||
|
"""Test text chunking with invalid parameters."""
|
||||||
|
docs = [MockDocument("test content")]
|
||||||
|
|
||||||
|
# Should handle negative chunk sizes gracefully
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
docs, chunk_size=0, chunk_overlap=0, ast_chunk_size=0, ast_chunk_overlap=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still return some result
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
def test_create_ast_chunks_no_language(self):
|
||||||
|
"""Test AST chunking with documents missing language metadata."""
|
||||||
|
docs = [MockDocument("def test(): pass", "/test/script.py")] # No language set
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should fall back to traditional chunking
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
assert len(chunks) >= 0 # May be empty if fallback also fails
|
||||||
|
|
||||||
|
def test_create_ast_chunks_empty_content(self):
|
||||||
|
"""Test AST chunking with empty content."""
|
||||||
|
docs = [MockDocument("", "/test/script.py", {"language": "python"})]
|
||||||
|
|
||||||
|
chunks = create_ast_chunks(docs)
|
||||||
|
|
||||||
|
# Should handle empty content gracefully
|
||||||
|
assert isinstance(chunks, list)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
14
tests/test_cli_ask.py
Normal file
14
tests/test_cli_ask.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from leann.cli import LeannCLI
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_ask_accepts_positional_query(tmp_path, monkeypatch):
|
||||||
|
monkeypatch.chdir(tmp_path)
|
||||||
|
|
||||||
|
cli = LeannCLI()
|
||||||
|
parser = cli.create_parser()
|
||||||
|
|
||||||
|
args = parser.parse_args(["ask", "my-docs", "Where are prompts configured?"])
|
||||||
|
|
||||||
|
assert args.command == "ask"
|
||||||
|
assert args.index_name == "my-docs"
|
||||||
|
assert args.query == "Where are prompts configured?"
|
||||||
@@ -57,6 +57,51 @@ def test_document_rag_simulated(test_data_dir):
|
|||||||
assert "This is a simulated answer" in output
|
assert "This is a simulated answer" in output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
os.environ.get("CI") == "true",
|
||||||
|
reason="Skip AST chunking tests in CI to avoid dependency issues",
|
||||||
|
)
|
||||||
|
def test_document_rag_with_ast_chunking(test_data_dir):
|
||||||
|
"""Test document_rag with AST-aware chunking enabled."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# Use a subdirectory that doesn't exist yet to force index creation
|
||||||
|
index_dir = Path(temp_dir) / "test_ast_index"
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"apps/document_rag.py",
|
||||||
|
"--llm",
|
||||||
|
"simulated",
|
||||||
|
"--embedding-model",
|
||||||
|
"facebook/contriever",
|
||||||
|
"--embedding-mode",
|
||||||
|
"sentence-transformers",
|
||||||
|
"--index-dir",
|
||||||
|
str(index_dir),
|
||||||
|
"--data-dir",
|
||||||
|
str(test_data_dir),
|
||||||
|
"--enable-code-chunking", # Enable AST chunking
|
||||||
|
"--query",
|
||||||
|
"What is Pride and Prejudice about?",
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HF_HUB_DISABLE_SYMLINKS"] = "1"
|
||||||
|
env["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True, timeout=600, env=env)
|
||||||
|
|
||||||
|
# Check return code
|
||||||
|
assert result.returncode == 0, f"Command failed: {result.stderr}"
|
||||||
|
|
||||||
|
# Verify output
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
assert "Index saved to" in output or "Using existing index" in output
|
||||||
|
assert "This is a simulated answer" in output
|
||||||
|
|
||||||
|
# Should mention AST chunking if code files are present
|
||||||
|
# (might not be relevant for the test data, but command should succeed)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OpenAI API key not available")
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
|
os.environ.get("CI") == "true", reason="Skip OpenAI tests in CI to avoid API costs"
|
||||||
|
|||||||
365
tests/test_metadata_filtering.py
Normal file
365
tests/test_metadata_filtering.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""
|
||||||
|
Comprehensive tests for metadata filtering functionality.
|
||||||
|
|
||||||
|
This module tests the MetadataFilterEngine class and its integration
|
||||||
|
with the LEANN search system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Import the modules we're testing
|
||||||
|
import sys
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import PassageManager, SearchResult
|
||||||
|
from leann.metadata_filter import MetadataFilterEngine
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataFilterEngine:
|
||||||
|
"""Test suite for the MetadataFilterEngine class."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
self.engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample search results for testing
|
||||||
|
self.sample_results = [
|
||||||
|
{
|
||||||
|
"id": "doc1",
|
||||||
|
"score": 0.95,
|
||||||
|
"text": "This is chapter 1 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["adventure", "fantasy"],
|
||||||
|
"word_count": 150,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc2",
|
||||||
|
"score": 0.87,
|
||||||
|
"text": "This is chapter 3 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 3,
|
||||||
|
"character": "Bob",
|
||||||
|
"tags": ["mystery", "thriller"],
|
||||||
|
"word_count": 250,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc3",
|
||||||
|
"score": 0.82,
|
||||||
|
"text": "This is chapter 5 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 5,
|
||||||
|
"character": "Alice",
|
||||||
|
"tags": ["romance", "drama"],
|
||||||
|
"word_count": 300,
|
||||||
|
"is_published": False,
|
||||||
|
"genre": "non-fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "doc4",
|
||||||
|
"score": 0.78,
|
||||||
|
"text": "This is chapter 10 content",
|
||||||
|
"metadata": {
|
||||||
|
"chapter": 10,
|
||||||
|
"character": "Charlie",
|
||||||
|
"tags": ["action", "adventure"],
|
||||||
|
"word_count": 400,
|
||||||
|
"is_published": True,
|
||||||
|
"genre": "fiction",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_engine_initialization(self):
|
||||||
|
"""Test that the filter engine initializes correctly."""
|
||||||
|
assert self.engine is not None
|
||||||
|
assert len(self.engine.operators) > 0
|
||||||
|
assert "==" in self.engine.operators
|
||||||
|
assert "contains" in self.engine.operators
|
||||||
|
assert "in" in self.engine.operators
|
||||||
|
|
||||||
|
def test_direct_instantiation(self):
|
||||||
|
"""Test direct instantiation of the engine."""
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
assert isinstance(engine, MetadataFilterEngine)
|
||||||
|
|
||||||
|
def test_no_filters_returns_all_results(self):
|
||||||
|
"""Test that passing None or empty filters returns all results."""
|
||||||
|
# Test with None
|
||||||
|
result = self.engine.apply_filters(self.sample_results, None)
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test with empty dict
|
||||||
|
result = self.engine.apply_filters(self.sample_results, {})
|
||||||
|
assert len(result) == len(self.sample_results)
|
||||||
|
|
||||||
|
# Test comparison operators
|
||||||
|
def test_equals_filter(self):
|
||||||
|
"""Test equals (==) filter."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["id"] == "doc1"
|
||||||
|
|
||||||
|
def test_not_equals_filter(self):
|
||||||
|
"""Test not equals (!=) filter."""
|
||||||
|
filters = {"genre": {"!=": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_less_than_filter(self):
|
||||||
|
"""Test less than (<) filter."""
|
||||||
|
filters = {"chapter": {"<": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch < 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_less_than_or_equal_filter(self):
|
||||||
|
"""Test less than or equal (<=) filter."""
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
chapters = [r["metadata"]["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_greater_than_filter(self):
|
||||||
|
"""Test greater than (>) filter."""
|
||||||
|
filters = {"word_count": {">": 200}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3 # Documents with word_count 250, 300, 400
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc > 200 for wc in word_counts)
|
||||||
|
|
||||||
|
def test_greater_than_or_equal_filter(self):
|
||||||
|
"""Test greater than or equal (>=) filter."""
|
||||||
|
filters = {"word_count": {">=": 250}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
word_counts = [r["metadata"]["word_count"] for r in result]
|
||||||
|
assert all(wc >= 250 for wc in word_counts)
|
||||||
|
|
||||||
|
# Test membership operators
|
||||||
|
def test_in_filter(self):
|
||||||
|
"""Test in filter."""
|
||||||
|
filters = {"character": {"in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
characters = [r["metadata"]["character"] for r in result]
|
||||||
|
assert all(ch in ["Alice", "Bob"] for ch in characters)
|
||||||
|
|
||||||
|
def test_not_in_filter(self):
|
||||||
|
"""Test not_in filter."""
|
||||||
|
filters = {"character": {"not_in": ["Alice", "Bob"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["character"] == "Charlie"
|
||||||
|
|
||||||
|
# Test string operators
|
||||||
|
def test_contains_filter(self):
|
||||||
|
"""Test contains filter."""
|
||||||
|
filters = {"genre": {"contains": "fiction"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # Both "fiction" and "non-fiction"
|
||||||
|
|
||||||
|
def test_starts_with_filter(self):
|
||||||
|
"""Test starts_with filter."""
|
||||||
|
filters = {"genre": {"starts_with": "non"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["metadata"]["genre"] == "non-fiction"
|
||||||
|
|
||||||
|
def test_ends_with_filter(self):
|
||||||
|
"""Test ends_with filter."""
|
||||||
|
filters = {"text": {"ends_with": "content"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 4 # All sample texts end with "content"
|
||||||
|
|
||||||
|
# Test boolean operators
|
||||||
|
def test_is_true_filter(self):
|
||||||
|
"""Test is_true filter."""
|
||||||
|
filters = {"is_published": {"is_true": True}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert all(r["metadata"]["is_published"] for r in result)
|
||||||
|
|
||||||
|
def test_is_false_filter(self):
|
||||||
|
"""Test is_false filter."""
|
||||||
|
filters = {"is_published": {"is_false": False}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert not result[0]["metadata"]["is_published"]
|
||||||
|
|
||||||
|
# Test compound filters (AND logic)
|
||||||
|
def test_compound_filters(self):
|
||||||
|
"""Test multiple filters applied together (AND logic)."""
|
||||||
|
filters = {"genre": {"==": "fiction"}, "chapter": {"<=": 5}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert r["metadata"]["genre"] == "fiction"
|
||||||
|
assert r["metadata"]["chapter"] <= 5
|
||||||
|
|
||||||
|
def test_multiple_operators_same_field(self):
|
||||||
|
"""Test multiple operators on the same field."""
|
||||||
|
filters = {"word_count": {">=": 200, "<=": 350}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
wc = r["metadata"]["word_count"]
|
||||||
|
assert 200 <= wc <= 350
|
||||||
|
|
||||||
|
# Test edge cases
|
||||||
|
def test_missing_field_fails_filter(self):
|
||||||
|
"""Test that missing metadata fields fail filters."""
|
||||||
|
filters = {"nonexistent_field": {"==": "value"}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
def test_invalid_operator(self):
|
||||||
|
"""Test that invalid operators are handled gracefully."""
|
||||||
|
filters = {"chapter": {"invalid_op": 1}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 0 # Should filter out all results
|
||||||
|
|
||||||
|
def test_type_coercion_numeric(self):
|
||||||
|
"""Test numeric type coercion in comparisons."""
|
||||||
|
# Add a result with string chapter number
|
||||||
|
test_results = [
|
||||||
|
*self.sample_results,
|
||||||
|
{
|
||||||
|
"id": "doc5",
|
||||||
|
"score": 0.75,
|
||||||
|
"text": "String chapter test",
|
||||||
|
"metadata": {"chapter": "2", "genre": "test"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
filters = {"chapter": {"<": 3}}
|
||||||
|
result = self.engine.apply_filters(test_results, filters)
|
||||||
|
# Should include doc1 (chapter=1) and doc5 (chapter="2")
|
||||||
|
assert len(result) == 2
|
||||||
|
ids = [r["id"] for r in result]
|
||||||
|
assert "doc1" in ids
|
||||||
|
assert "doc5" in ids
|
||||||
|
|
||||||
|
def test_list_membership_with_nested_tags(self):
|
||||||
|
"""Test membership operations with list metadata."""
|
||||||
|
# Note: This tests the metadata structure, not list field filtering
|
||||||
|
# For list field filtering, we'd need to modify the test data
|
||||||
|
filters = {"character": {"in": ["Alice"]}}
|
||||||
|
result = self.engine.apply_filters(self.sample_results, filters)
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(r["metadata"]["character"] == "Alice" for r in result)
|
||||||
|
|
||||||
|
def test_empty_results_list(self):
|
||||||
|
"""Test filtering on empty results list."""
|
||||||
|
filters = {"chapter": {"==": 1}}
|
||||||
|
result = self.engine.apply_filters([], filters)
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPassageManagerFiltering:
|
||||||
|
"""Test suite for PassageManager filtering integration."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
# Mock the passage manager without actual file I/O
|
||||||
|
self.passage_manager = Mock(spec=PassageManager)
|
||||||
|
self.passage_manager.filter_engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
# Sample SearchResult objects
|
||||||
|
self.search_results = [
|
||||||
|
SearchResult(
|
||||||
|
id="doc1",
|
||||||
|
score=0.95,
|
||||||
|
text="Chapter 1 content",
|
||||||
|
metadata={"chapter": 1, "character": "Alice"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc2",
|
||||||
|
score=0.87,
|
||||||
|
text="Chapter 5 content",
|
||||||
|
metadata={"chapter": 5, "character": "Bob"},
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
id="doc3",
|
||||||
|
score=0.82,
|
||||||
|
text="Chapter 10 content",
|
||||||
|
metadata={"chapter": 10, "character": "Alice"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_search_result_filtering(self):
|
||||||
|
"""Test filtering SearchResult objects."""
|
||||||
|
# Create a real PassageManager instance just for the filtering method
|
||||||
|
# We'll mock the file operations
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"chapter": {"<=": 5}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
chapters = [r.metadata["chapter"] for r in result]
|
||||||
|
assert all(ch <= 5 for ch in chapters)
|
||||||
|
|
||||||
|
def test_filter_search_results_no_filters(self):
|
||||||
|
"""Test that None filters return all results."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
result = pm.filter_search_results(self.search_results, None)
|
||||||
|
assert len(result) == len(self.search_results)
|
||||||
|
|
||||||
|
def test_filter_maintains_search_result_type(self):
|
||||||
|
"""Test that filtering returns SearchResult objects."""
|
||||||
|
with patch("builtins.open"), patch("json.loads"), patch("pickle.load"):
|
||||||
|
pm = PassageManager([{"type": "jsonl", "path": "test.jsonl"}])
|
||||||
|
|
||||||
|
filters = {"character": {"==": "Alice"}}
|
||||||
|
result = pm.filter_search_results(self.search_results, filters)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
for r in result:
|
||||||
|
assert isinstance(r, SearchResult)
|
||||||
|
assert r.metadata["character"] == "Alice"
|
||||||
|
|
||||||
|
|
||||||
|
# Integration tests would go here, but they require actual LEANN backend setup
|
||||||
|
# These would test the full pipeline from LeannSearcher.search() with metadata_filters
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Run basic smoke tests
|
||||||
|
engine = MetadataFilterEngine()
|
||||||
|
|
||||||
|
sample_data = [
|
||||||
|
{
|
||||||
|
"id": "test1",
|
||||||
|
"score": 0.9,
|
||||||
|
"text": "Test content",
|
||||||
|
"metadata": {"chapter": 1, "published": True},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test basic filtering
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 1}})
|
||||||
|
assert len(result) == 1
|
||||||
|
print("✅ Basic filtering test passed")
|
||||||
|
|
||||||
|
result = engine.apply_filters(sample_data, {"chapter": {"==": 2}})
|
||||||
|
assert len(result) == 0
|
||||||
|
print("✅ No match filtering test passed")
|
||||||
|
|
||||||
|
print("🎉 All smoke tests passed!")
|
||||||
Reference in New Issue
Block a user