Compare commits
236 Commits
apps
...
fix-update
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd5c052bd8 | ||
|
|
2f77d0185c | ||
|
|
82d536b2ae | ||
|
|
e2b37914ce | ||
|
|
e588100674 | ||
|
|
f42e086383 | ||
|
|
fecee94af1 | ||
|
|
01475c10a0 | ||
|
|
c8aa063f48 | ||
|
|
576beb13db | ||
|
|
63c7b0c8a3 | ||
|
|
ec889f7ef4 | ||
|
|
322e5c162d | ||
|
|
edde0cdeb2 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 | ||
|
|
609fa62fd5 | ||
|
|
eab13434ef | ||
|
|
b2390ccc14 | ||
|
|
e8fca2c84a | ||
|
|
790ae14f69 | ||
|
|
ac363072e6 | ||
|
|
93465af46c | ||
|
|
792ece67dc | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 | ||
|
|
cf2ef48967 | ||
|
|
0692bbf7a2 | ||
|
|
52584a171f | ||
|
|
efd6b5324b | ||
|
|
2baaa4549b | ||
|
|
35310ddd52 | ||
|
|
fc9c5cb39d | ||
|
|
8f2a1e87ea | ||
|
|
50caf65f28 | ||
|
|
1b48794ca8 | ||
|
|
4aef1d814e | ||
|
|
75ddcd6158 | ||
|
|
2a4df11f5c | ||
|
|
5eb893c62b | ||
|
|
d91ce2e94d | ||
|
|
5c2ff8a641 | ||
|
|
d4f474c9b7 | ||
|
|
170f7644e9 | ||
|
|
cd8b970eff | ||
|
|
52153bbb69 | ||
|
|
e1ae087207 | ||
|
|
48c5e12ac1 | ||
|
|
f8b5c97190 | ||
|
|
d038c81b8b | ||
|
|
29cbbbd0d6 | ||
|
|
179f30bc36 | ||
|
|
c4a0a68581 | ||
|
|
5c836ad08e | ||
|
|
673fd9b7cd | ||
|
|
84b24b233d | ||
|
|
499cdd7822 | ||
|
|
800d4cf111 | ||
|
|
b6d43f5fd9 | ||
|
|
3603cd5034 | ||
|
|
6df7893173 | ||
|
|
e64b599276 | ||
|
|
2dd59c4ba1 | ||
|
|
166986d5e6 | ||
|
|
a6aec68f32 | ||
|
|
ed27a127d5 | ||
|
|
d8b4ea7564 | ||
|
|
f0a2ef96b4 | ||
|
|
7d73c2c803 | ||
|
|
e8d2ecab03 | ||
|
|
32a374d094 | ||
|
|
d45c013806 | ||
|
|
9000a7083d | ||
|
|
8307555d54 | ||
|
|
20f2aece08 | ||
|
|
43eb4f9a1d | ||
|
|
5461b71d8c | ||
|
|
374db0ebb8 | ||
|
|
cea1f6f87c | ||
|
|
6c0e39372b | ||
|
|
2bec67d2b6 | ||
|
|
133e715832 | ||
|
|
95cf2f16e2 | ||
|
|
47a4c153eb | ||
|
|
faf5ae3533 | ||
|
|
a44dccecac | ||
|
|
9cf9358b9c | ||
|
|
de252fef31 | ||
|
|
9076bc27b8 | ||
|
|
50686c0819 | ||
|
|
1614203786 | ||
|
|
3d4c75a56c | ||
|
|
2684ee71dc | ||
|
|
1d321953ba | ||
|
|
b3cb251369 | ||
|
|
0a17d2c9d8 | ||
|
|
e3defbca84 | ||
|
|
e407f63977 | ||
|
|
7add391b2c | ||
|
|
efd6373b32 | ||
|
|
d502fa24b0 | ||
|
|
258a9a5c7f | ||
|
|
5d41ac6115 | ||
|
|
2a0fdb49b8 | ||
|
|
9d1b7231b6 | ||
|
|
ed3095b478 | ||
|
|
88eca75917 | ||
|
|
42de27e16a | ||
|
|
c083bda5b7 | ||
|
|
e86da38726 | ||
|
|
99076e38bc | ||
|
|
9698c1a02c | ||
|
|
851f0f04c3 | ||
|
|
ae16d9d888 | ||
|
|
6e1af2eb0c | ||
|
|
7695dd0d50 | ||
|
|
c2065473ad | ||
|
|
5f3870564d | ||
|
|
c214b2e33e | ||
|
|
2420c5fd35 | ||
|
|
f48f526f0a | ||
|
|
5dd74982ba | ||
|
|
e07aaf52a7 | ||
|
|
30e5f12616 | ||
|
|
594427bf87 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
|||||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
name: Bug Report
|
||||||
|
description: Report a bug in LEANN
|
||||||
|
labels: ["bug"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: description
|
||||||
|
attributes:
|
||||||
|
label: What happened?
|
||||||
|
description: A clear description of the bug
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: reproduce
|
||||||
|
attributes:
|
||||||
|
label: How to reproduce
|
||||||
|
placeholder: |
|
||||||
|
1. Install with...
|
||||||
|
2. Run command...
|
||||||
|
3. See error
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: error
|
||||||
|
attributes:
|
||||||
|
label: Error message
|
||||||
|
description: Paste any error messages
|
||||||
|
render: shell
|
||||||
|
|
||||||
|
- type: input
|
||||||
|
id: version
|
||||||
|
attributes:
|
||||||
|
label: LEANN Version
|
||||||
|
placeholder: "0.1.0"
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: dropdown
|
||||||
|
id: os
|
||||||
|
attributes:
|
||||||
|
label: Operating System
|
||||||
|
options:
|
||||||
|
- macOS
|
||||||
|
- Linux
|
||||||
|
- Windows
|
||||||
|
- Docker
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Documentation
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||||
|
about: Read the docs first
|
||||||
|
- name: Discussions
|
||||||
|
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||||
|
about: Ask questions and share ideas
|
||||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Feature Request
|
||||||
|
description: Suggest a new feature for LEANN
|
||||||
|
labels: ["enhancement"]
|
||||||
|
|
||||||
|
body:
|
||||||
|
- type: textarea
|
||||||
|
id: problem
|
||||||
|
attributes:
|
||||||
|
label: What problem does this solve?
|
||||||
|
description: Describe the problem or need
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: solution
|
||||||
|
attributes:
|
||||||
|
label: Proposed solution
|
||||||
|
description: How would you like this to work?
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
|
||||||
|
- type: textarea
|
||||||
|
id: example
|
||||||
|
attributes:
|
||||||
|
label: Example usage
|
||||||
|
description: Show how the API might look
|
||||||
|
render: python
|
||||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
## What does this PR do?
|
||||||
|
|
||||||
|
<!-- Brief description of your changes -->
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
Fixes #
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [ ] Tests pass (`uv run pytest`)
|
||||||
|
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||||
|
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||||
12
.github/workflows/build-and-publish.yml
vendored
Normal file
12
.github/workflows/build-and-publish.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main ]
|
||||||
|
pull_request:
|
||||||
|
branches: [ main ]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
451
.github/workflows/build-reusable.yml
vendored
Normal file
451
.github/workflows/build-reusable.yml
vendored
Normal file
@@ -0,0 +1,451 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint and Format Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install uv and Python
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Run pre-commit with only lint group (no project deps)
|
||||||
|
run: |
|
||||||
|
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||||
|
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: lint
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.12'
|
||||||
|
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||||
|
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install uv and Python
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
patchelf
|
||||||
|
|
||||||
|
# Debug: Show system information
|
||||||
|
echo "🔍 System Information:"
|
||||||
|
echo "Architecture: $(uname -m)"
|
||||||
|
echo "OS: $(uname -a)"
|
||||||
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
|
|
||||||
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Don't install LLVM, use system clang for better compatibility
|
||||||
|
brew install libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv python install ${{ matrix.python }}
|
||||||
|
uv venv --python ${{ matrix.python }} .uv-build
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||||
|
else
|
||||||
|
BUILD_PY=".uv-build/bin/python"
|
||||||
|
fi
|
||||||
|
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --python "$BUILD_PY" auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --python "$BUILD_PY" delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||||
|
else
|
||||||
|
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Set macOS environment variables
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Use brew --prefix to automatically detect Homebrew installation path
|
||||||
|
HOMEBREW_PREFIX=$(brew --prefix)
|
||||||
|
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||||
|
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set compiler flags for OpenMP (required for both backends)
|
||||||
|
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||||
|
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
|
# Use system clang for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
else
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
|
# Use system clang for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
|
# But Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
else
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Determine deployment target based on runner OS
|
||||||
|
# Must match the Homebrew libraries for each macOS version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
HNSW_TARGET="13.0"
|
||||||
|
DISKANN_TARGET="13.3"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
HNSW_TARGET="14.0"
|
||||||
|
DISKANN_TARGET="14.0"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
HNSW_TARGET="15.0"
|
||||||
|
DISKANN_TARGET="15.0"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
|
||||||
|
- name: Install built packages for testing
|
||||||
|
run: |
|
||||||
|
# Create uv-managed virtual environment with the requested interpreter
|
||||||
|
uv python install ${{ matrix.python }}
|
||||||
|
uv venv --python ${{ matrix.python }}
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
UV_PY=".venv\\Scripts\\python.exe"
|
||||||
|
else
|
||||||
|
UV_PY=".venv/bin/python"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install test dependency group only (avoids reinstalling project package)
|
||||||
|
uv pip install --python "$UV_PY" --group test
|
||||||
|
|
||||||
|
# Install core wheel built in this job
|
||||||
|
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$CORE_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
|
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$HNSW_WHL" ]]; then
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$HNSW_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||||
|
fi
|
||||||
|
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$DISKANN_WHL" ]]; then
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$DISKANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||||
|
fi
|
||||||
|
|
||||||
|
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$LEANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
env:
|
||||||
|
CI: true
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
|
TOKENIZERS_PARALLELISM: false
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
- name: Run sanity checks (optional)
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run distance function tests if available
|
||||||
|
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||||
|
echo "Running distance function sanity checks..."
|
||||||
|
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
|
|
||||||
|
|
||||||
|
arch-smoke:
|
||||||
|
name: Arch Linux smoke test (install & import)
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: archlinux:latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare system
|
||||||
|
run: |
|
||||||
|
pacman -Syu --noconfirm
|
||||||
|
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||||
|
|
||||||
|
- name: Download ALL wheel artifacts from this run
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
# Don't specify name, download all artifacts
|
||||||
|
path: ./wheels
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Create virtual environment and install wheels
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
uv pip install --find-links wheels leann-core
|
||||||
|
uv pip install --find-links wheels leann-backend-hnsw
|
||||||
|
uv pip install --find-links wheels leann-backend-diskann
|
||||||
|
uv pip install --find-links wheels leann
|
||||||
|
|
||||||
|
- name: Import & tiny runtime check
|
||||||
|
env:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
python - <<'PY'
|
||||||
|
import leann
|
||||||
|
import leann_backend_hnsw as h
|
||||||
|
import leann_backend_diskann as d
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
b = LeannBuilder(backend_name="hnsw")
|
||||||
|
b.add_text("hello arch")
|
||||||
|
b.build_index("arch_demo.leann")
|
||||||
|
s = LeannSearcher("arch_demo.leann")
|
||||||
|
print("search:", s.search("hello", top_k=1))
|
||||||
|
PY
|
||||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
name: Link Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [ main, master ]
|
||||||
|
pull_request:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 3 * * 1"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
link-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: lycheeverse/lychee-action@v2
|
||||||
|
with:
|
||||||
|
args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
129
.github/workflows/release-manual.yml
vendored
Normal file
129
.github/workflows/release-manual.yml
vendored
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
name: Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
version:
|
||||||
|
description: 'Version to release (e.g., 0.1.2)'
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
update-version:
|
||||||
|
name: Update Version
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
outputs:
|
||||||
|
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Validate version
|
||||||
|
run: |
|
||||||
|
# Remove 'v' prefix if present for validation
|
||||||
|
VERSION_CLEAN="${{ inputs.version }}"
|
||||||
|
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||||
|
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||||
|
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
- name: Update versions and push
|
||||||
|
id: push
|
||||||
|
run: |
|
||||||
|
# Check current version
|
||||||
|
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||||
|
echo "Current version: $CURRENT_VERSION"
|
||||||
|
echo "Target version: ${{ inputs.version }}"
|
||||||
|
|
||||||
|
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||||
|
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
else
|
||||||
|
./scripts/bump_version.sh ${{ inputs.version }}
|
||||||
|
git config user.name "GitHub Actions"
|
||||||
|
git config user.email "actions@github.com"
|
||||||
|
git add packages/*/pyproject.toml
|
||||||
|
git commit -m "chore: release v${{ inputs.version }}"
|
||||||
|
git push origin main
|
||||||
|
COMMIT_SHA=$(git rev-parse HEAD)
|
||||||
|
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
build-packages:
|
||||||
|
name: Build packages
|
||||||
|
needs: update-version
|
||||||
|
uses: ./.github/workflows/build-reusable.yml
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish and Release
|
||||||
|
needs: [update-version, build-packages]
|
||||||
|
if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'main'
|
||||||
|
|
||||||
|
- name: Download all artifacts
|
||||||
|
uses: actions/download-artifact@v4
|
||||||
|
with:
|
||||||
|
path: dist-artifacts
|
||||||
|
|
||||||
|
- name: Collect packages
|
||||||
|
run: |
|
||||||
|
mkdir -p dist
|
||||||
|
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||||
|
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||||
|
|
||||||
|
echo "📦 Packages to publish:"
|
||||||
|
ls -la dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
env:
|
||||||
|
TWINE_USERNAME: __token__
|
||||||
|
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$TWINE_PASSWORD" ]; then
|
||||||
|
echo "❌ PYPI_API_TOKEN not configured!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
pip install twine
|
||||||
|
twine upload dist/* --skip-existing --verbose
|
||||||
|
|
||||||
|
echo "✅ Published to PyPI!"
|
||||||
|
|
||||||
|
- name: Create release
|
||||||
|
run: |
|
||||||
|
# Check if tag already exists
|
||||||
|
if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
|
||||||
|
else
|
||||||
|
git tag "v${{ inputs.version }}"
|
||||||
|
git push origin "v${{ inputs.version }}"
|
||||||
|
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if release already exists
|
||||||
|
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||||
|
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||||
|
else
|
||||||
|
gh release create "v${{ inputs.version }}" \
|
||||||
|
--title "Release v${{ inputs.version }}" \
|
||||||
|
--notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
|
||||||
|
--latest
|
||||||
|
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||||
|
fi
|
||||||
|
env:
|
||||||
|
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
30
.gitignore
vendored
30
.gitignore
vendored
@@ -12,16 +12,18 @@ outputs/
|
|||||||
*.idx
|
*.idx
|
||||||
*.map
|
*.map
|
||||||
.history/
|
.history/
|
||||||
scripts/
|
|
||||||
lm_eval.egg-info/
|
lm_eval.egg-info/
|
||||||
demo/experiment_results/**/*.json
|
demo/experiment_results/**/*.json
|
||||||
*.jsonl
|
*.jsonl
|
||||||
*.eml
|
*.eml
|
||||||
*.emlx
|
*.emlx
|
||||||
*.json
|
*.json
|
||||||
|
*.png
|
||||||
|
!.vscode/*.json
|
||||||
*.sh
|
*.sh
|
||||||
*.txt
|
*.txt
|
||||||
!CMakeLists.txt
|
!CMakeLists.txt
|
||||||
|
!llms.txt
|
||||||
latency_breakdown*.json
|
latency_breakdown*.json
|
||||||
experiment_results/eval_results/diskann/*.json
|
experiment_results/eval_results/diskann/*.json
|
||||||
aws/
|
aws/
|
||||||
@@ -35,11 +37,15 @@ build/
|
|||||||
nprobe_logs/
|
nprobe_logs/
|
||||||
micro/results
|
micro/results
|
||||||
micro/contriever-INT8
|
micro/contriever-INT8
|
||||||
examples/data/*
|
data/*
|
||||||
!examples/data/2501.14312v1 (1).pdf
|
!data/2501.14312v1 (1).pdf
|
||||||
!examples/data/2506.08276v1.pdf
|
!data/2506.08276v1.pdf
|
||||||
!examples/data/PrideandPrejudice.txt
|
!data/PrideandPrejudice.txt
|
||||||
!examples/data/README.md
|
!data/huawei_pangu.md
|
||||||
|
!data/ground_truth/
|
||||||
|
!data/indices/
|
||||||
|
!data/queries/
|
||||||
|
!data/.gitattributes
|
||||||
*.qdstrm
|
*.qdstrm
|
||||||
benchmark_results/
|
benchmark_results/
|
||||||
results/
|
results/
|
||||||
@@ -87,3 +93,15 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
|||||||
*.passages.json
|
*.passages.json
|
||||||
|
|
||||||
batchtest.py
|
batchtest.py
|
||||||
|
tests/__pytest_cache__/
|
||||||
|
tests/__pycache__/
|
||||||
|
benchmarks/data/
|
||||||
|
|
||||||
|
## multi vector
|
||||||
|
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||||
|
|
||||||
|
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
||||||
|
# If you need to commit a specific demo PDF, remove this negation locally.
|
||||||
|
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||||
|
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||||
|
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||||
|
|||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
|||||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||||
url = https://github.com/zeromq/libzmq.git
|
url = https://github.com/zeromq/libzmq.git
|
||||||
|
[submodule "packages/astchunk-leann"]
|
||||||
|
path = packages/astchunk-leann
|
||||||
|
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||||
|
|||||||
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v5.0.0
|
||||||
|
hooks:
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- id: end-of-file-fixer
|
||||||
|
- id: check-yaml
|
||||||
|
- id: check-added-large-files
|
||||||
|
- id: check-merge-conflict
|
||||||
|
- id: debug-statements
|
||||||
|
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix, --exit-non-zero-on-fix]
|
||||||
|
- id: ruff-format
|
||||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"recommendations": [
|
||||||
|
"charliermarsh.ruff",
|
||||||
|
]
|
||||||
|
}
|
||||||
22
.vscode/settings.json
vendored
Normal file
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||||
|
"python.terminal.activateEnvironment": true,
|
||||||
|
"[python]": {
|
||||||
|
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||||
|
"editor.formatOnSave": true,
|
||||||
|
"editor.codeActionsOnSave": {
|
||||||
|
"source.organizeImports": "explicit",
|
||||||
|
"source.fixAll": "explicit"
|
||||||
|
},
|
||||||
|
"editor.insertSpaces": true,
|
||||||
|
"editor.tabSize": 4
|
||||||
|
},
|
||||||
|
"ruff.enable": true,
|
||||||
|
"files.watcherExclude": {
|
||||||
|
"**/.venv/**": true,
|
||||||
|
"**/__pycache__/**": true,
|
||||||
|
"**/*.egg-info/**": true,
|
||||||
|
"**/build/**": true,
|
||||||
|
"**/dist/**": true
|
||||||
|
}
|
||||||
|
}
|
||||||
387
apps/base_rag_example.py
Normal file
387
apps/base_rag_example.py
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
"""
|
||||||
|
Base class for unified RAG examples interface.
|
||||||
|
Provides common parameters and functionality for all RAG examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRAGExample(ABC):
|
||||||
|
"""Base class for all RAG examples with unified interface."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
default_index_name: str,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.default_index_name = default_index_name
|
||||||
|
self.parser = self._create_parser()
|
||||||
|
|
||||||
|
def _create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
"""Create argument parser with common parameters."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Core parameters (all examples share these)
|
||||||
|
core_group = parser.add_argument_group("Core Parameters")
|
||||||
|
core_group.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default=f"./{self.default_index_name}",
|
||||||
|
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Query to run (if not provided, will run in interactive mode)",
|
||||||
|
)
|
||||||
|
# Allow subclasses to override default max_items
|
||||||
|
max_items_default = getattr(self, "max_items_default", -1)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--max-items",
|
||||||
|
type=int,
|
||||||
|
default=max_items_default,
|
||||||
|
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding parameters
|
||||||
|
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||||
|
# Allow subclasses to override default embedding_model
|
||||||
|
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default=embedding_model_default,
|
||||||
|
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM parameters
|
||||||
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
choices=["openai", "ollama", "hf", "simulated"],
|
||||||
|
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--thinking-budget",
|
||||||
|
type=str,
|
||||||
|
choices=["low", "medium", "high"],
|
||||||
|
default=None,
|
||||||
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# AST Chunking parameters
|
||||||
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Overlap between AST chunks (default: 64)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--code-file-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search parameters
|
||||||
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
|
search_group.add_argument(
|
||||||
|
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||||
|
)
|
||||||
|
search_group.add_argument(
|
||||||
|
"--search-complexity",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Search complexity for graph traversal (default: 64)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index building parameters
|
||||||
|
index_group = parser.add_argument_group("Index Building Parameters")
|
||||||
|
index_group.add_argument(
|
||||||
|
"--backend-name",
|
||||||
|
type=str,
|
||||||
|
default="hnsw",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
help="Backend to use for index (default: hnsw)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--graph-degree",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Graph degree for index construction (default: 32)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--build-complexity",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Build complexity for index construction (default: 64)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-compact",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable compact index storage",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable embedding recomputation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add source-specific parameters
|
||||||
|
self._add_specific_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add source-specific arguments. Override in subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load data from the source. Returns list of text chunks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_llm_config(self, args) -> dict[str, Any]:
|
||||||
|
"""Get LLM configuration based on arguments."""
|
||||||
|
config = {"type": args.llm}
|
||||||
|
|
||||||
|
if args.llm == "openai":
|
||||||
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||||
|
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||||
|
if resolved_key:
|
||||||
|
config["api_key"] = resolved_key
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
|
config["host"] = resolve_ollama_host(args.llm_host)
|
||||||
|
elif args.llm == "hf":
|
||||||
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
elif args.llm == "simulated":
|
||||||
|
# Simulated LLM doesn't need additional configuration
|
||||||
|
pass
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
|
||||||
|
"""Build LEANN index from texts."""
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
|
||||||
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.build_complexity,
|
||||||
|
is_compact=not args.no_compact,
|
||||||
|
is_recompute=not args.no_recompute,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add texts in batches for better progress tracking
|
||||||
|
batch_size = 1000
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
for text in batch:
|
||||||
|
builder.add_text(text)
|
||||||
|
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||||
|
|
||||||
|
print("Building index structure...")
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
# Register project directory so leann list can discover this index
|
||||||
|
# The index is saved as args.index_dir/index_name.leann
|
||||||
|
# We want to register the current working directory where the app is run
|
||||||
|
register_project_directory(Path.cwd())
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
|
"""Run interactive chat with the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||||
|
print("Type 'quit' or 'exit' to stop.\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("You: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
)
|
||||||
|
print(f"\nAssistant: {response}\n")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
async def run_single_query(self, args, index_path: str, query: str):
|
||||||
|
"""Run a single query against the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||||
|
)
|
||||||
|
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Main entry point for the example."""
|
||||||
|
args = self.parser.parse_args()
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
index_exists = Path(args.index_dir).exists()
|
||||||
|
|
||||||
|
if not index_exists or args.force_rebuild:
|
||||||
|
# Load data and build index
|
||||||
|
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||||
|
texts = await self.load_data(args)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
print("No data found to index!")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_path = await self.build_index(args, texts)
|
||||||
|
else:
|
||||||
|
print(f"\nUsing existing index in {args.index_dir}")
|
||||||
|
|
||||||
|
# Run query or interactive mode
|
||||||
|
if args.query:
|
||||||
|
await self.run_single_query(args, index_path, args.query)
|
||||||
|
else:
|
||||||
|
await self.run_interactive_chat(args, index_path)
|
||||||
171
apps/browser_rag.py
Normal file
171
apps/browser_rag.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""
|
||||||
|
Browser History RAG example using the unified interface.
|
||||||
|
Supports Chrome browser history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Chrome browser history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Browser History",
|
||||||
|
description="Process and query Chrome browser history with LEANN",
|
||||||
|
default_index_name="google_history_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add browser-specific arguments."""
|
||||||
|
browser_group = parser.add_argument_group("Browser Parameters")
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chrome-profile",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--auto-find-profiles",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Automatically find all Chrome profiles (default: True)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_chrome_base_path(self) -> Path:
|
||||||
|
"""Get the base Chrome profile path based on OS."""
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||||
|
elif sys.platform.startswith("linux"):
|
||||||
|
return Path.home() / ".config" / "google-chrome"
|
||||||
|
elif sys.platform == "win32":
|
||||||
|
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||||
|
|
||||||
|
def _find_chrome_profiles(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Chrome profiles."""
|
||||||
|
base_path = self._get_chrome_base_path()
|
||||||
|
if not base_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
profiles = []
|
||||||
|
|
||||||
|
# Check Default profile
|
||||||
|
default_profile = base_path / "Default"
|
||||||
|
if default_profile.exists() and (default_profile / "History").exists():
|
||||||
|
profiles.append(default_profile)
|
||||||
|
|
||||||
|
# Check numbered profiles
|
||||||
|
for item in base_path.iterdir():
|
||||||
|
if item.is_dir() and item.name.startswith("Profile "):
|
||||||
|
if (item / "History").exists():
|
||||||
|
profiles.append(item)
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load browser history and convert to text chunks."""
|
||||||
|
# Determine Chrome profiles
|
||||||
|
if args.chrome_profile and not args.auto_find_profiles:
|
||||||
|
profile_dirs = [Path(args.chrome_profile)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Chrome profiles...")
|
||||||
|
profile_dirs = self._find_chrome_profiles()
|
||||||
|
|
||||||
|
# If specific profile given, filter to just that one
|
||||||
|
if args.chrome_profile:
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||||
|
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found!")
|
||||||
|
print("Please specify --chrome-profile manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
# Process each profile
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per profile
|
||||||
|
max_per_profile = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_profile = remaining
|
||||||
|
|
||||||
|
# Load history
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_per_profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} history entries from this profile")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No browser history found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for browser history RAG
|
||||||
|
print("\n🌐 Browser History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What websites did I visit about machine learning?'")
|
||||||
|
print("- 'Find my search history about programming'")
|
||||||
|
print("- 'What YouTube videos did I watch recently?'")
|
||||||
|
print("- 'Show me websites about travel planning'")
|
||||||
|
print("\nNote: Make sure Chrome is closed before running\n")
|
||||||
|
|
||||||
|
rag = BrowserRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
44
apps/chunking/__init__.py
Normal file
44
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Unified chunking utilities facade.
|
||||||
|
|
||||||
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"create_ast_chunks",
|
||||||
|
"create_text_chunks",
|
||||||
|
"create_traditional_chunks",
|
||||||
|
"detect_code_files",
|
||||||
|
"get_language_from_extension",
|
||||||
|
]
|
||||||
211
apps/code_rag.py
Normal file
211
apps/code_rag.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||||
|
Specialized for code repositories with automatic language detection and
|
||||||
|
optimized chunking parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRAG(BaseRAGExample):
|
||||||
|
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Code",
|
||||||
|
description="Process and query code repositories with AST-aware chunking",
|
||||||
|
default_index_name="code_index",
|
||||||
|
)
|
||||||
|
# Override defaults for code-specific usage
|
||||||
|
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||||
|
self.max_items_default = -1 # Process all code files by default
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add code-specific arguments."""
|
||||||
|
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||||
|
|
||||||
|
code_group.add_argument(
|
||||||
|
"--repo-dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Code repository directory to index (default: current directory)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=list(CODE_EXTENSIONS.keys()),
|
||||||
|
help="File extensions to include (default: supported code extensions)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--exclude-dirs",
|
||||||
|
nargs="+",
|
||||||
|
default=[
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
".venv",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
],
|
||||||
|
help="Directories to exclude from indexing",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=int,
|
||||||
|
default=1000000, # 1MB
|
||||||
|
help="Maximum file size in bytes to process (default: 1MB)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-comments",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in chunking (useful for documentation)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--preserve-imports",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Try to preserve import statements in chunks (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load code files and convert to AST-aware chunks."""
|
||||||
|
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||||
|
print(f"📁 Including extensions: {args.include_extensions}")
|
||||||
|
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||||
|
|
||||||
|
# Check if repository directory exists
|
||||||
|
repo_path = Path(args.repo_dir)
|
||||||
|
if not repo_path.exists():
|
||||||
|
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||||
|
|
||||||
|
# Load code files with filtering
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"required_exts": args.include_extensions,
|
||||||
|
"exclude_hidden": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create exclusion filter
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
"""Filter out unwanted files and directories."""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
if path.stat().st_size > args.max_file_size:
|
||||||
|
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if in excluded directory
|
||||||
|
for exclude_dir in args.exclude_dirs:
|
||||||
|
if exclude_dir in path.parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load documents with file filtering
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.repo_dir,
|
||||||
|
file_extractor=None, # Use default extractors
|
||||||
|
**reader_kwargs,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
# Apply custom filtering
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_filter(file_path):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
|
documents = filtered_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error loading code files: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(
|
||||||
|
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"✅ Loaded {len(documents)} code files")
|
||||||
|
|
||||||
|
# Show breakdown by language/extension
|
||||||
|
ext_counts = {}
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||||
|
|
||||||
|
print("📊 Files by extension:")
|
||||||
|
for ext, count in sorted(ext_counts.items()):
|
||||||
|
print(f" {ext}: {count} files")
|
||||||
|
|
||||||
|
# Use AST-aware chunking by default for code
|
||||||
|
print(
|
||||||
|
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=256, # Fallback for non-code files
|
||||||
|
chunk_overlap=64,
|
||||||
|
use_ast_chunking=True, # Always use AST for code RAG
|
||||||
|
ast_chunk_size=args.ast_chunk_size,
|
||||||
|
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||||
|
code_file_extensions=args.include_extensions,
|
||||||
|
ast_fallback_traditional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for code RAG
|
||||||
|
print("\n💻 Code RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'How does the embedding computation work?'")
|
||||||
|
print("- 'What are the main classes in this codebase?'")
|
||||||
|
print("- 'Show me the search implementation'")
|
||||||
|
print("- 'How is error handling implemented?'")
|
||||||
|
print("- 'What design patterns are used?'")
|
||||||
|
print("- 'Explain the chunking logic'")
|
||||||
|
print("\n🚀 Features:")
|
||||||
|
print("- ✅ AST-aware chunking preserves code structure")
|
||||||
|
print("- ✅ Automatic language detection")
|
||||||
|
print("- ✅ Smart filtering of large files and common excludes")
|
||||||
|
print("- ✅ Optimized for code understanding")
|
||||||
|
print("\nUsage examples:")
|
||||||
|
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||||
|
print(
|
||||||
|
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = CodeRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
131
apps/document_rag.py
Normal file
131
apps/document_rag.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
Document RAG example using the unified interface.
|
||||||
|
Supports PDF, TXT, MD, and other document formats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRAG(BaseRAGExample):
|
||||||
|
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Document",
|
||||||
|
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||||
|
default_index_name="test_doc_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add document-specific arguments."""
|
||||||
|
doc_group = parser.add_argument_group("Document Parameters")
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="data",
|
||||||
|
help="Directory containing documents to index (default: data)",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--file-types",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--enable-code-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files in the data directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load documents and convert to text chunks."""
|
||||||
|
print(f"Loading documents from: {args.data_dir}")
|
||||||
|
if args.file_types:
|
||||||
|
print(f"Filtering by file types: {args.file_types}")
|
||||||
|
else:
|
||||||
|
print("Processing all supported file types")
|
||||||
|
|
||||||
|
# Check if data directory exists
|
||||||
|
data_path = Path(args.data_dir)
|
||||||
|
if not data_path.exists():
|
||||||
|
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
}
|
||||||
|
if args.file_types:
|
||||||
|
reader_kwargs["required_exts"] = args.file_types
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
|
# Determine chunking strategy
|
||||||
|
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("Using AST-aware chunking for code files")
|
||||||
|
|
||||||
|
# Convert to text chunks with optional AST support
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
use_ast_chunking=use_ast,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||||
|
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for document RAG
|
||||||
|
print("\n📄 Document RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What are the main techniques LEANN uses?'")
|
||||||
|
print("- 'What is the technique DLPM?'")
|
||||||
|
print("- 'Who does Elizabeth Bennet marry?'")
|
||||||
|
print(
|
||||||
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
|
)
|
||||||
|
print("\n🚀 NEW: Code-aware chunking available!")
|
||||||
|
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||||
|
print("- Supports Python, Java, C#, TypeScript files")
|
||||||
|
print("- Better semantic understanding of code structure")
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = DocumentRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import email
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Recursively find all 'Messages' directories under the given root.
|
||||||
|
Returns a list of Path objects.
|
||||||
|
"""
|
||||||
|
if root is None:
|
||||||
|
# Auto-detect user's mail path
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
root = os.path.join(home_dir, "Library", "Mail")
|
||||||
|
|
||||||
|
messages_dirs = []
|
||||||
|
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||||
|
if os.path.basename(dirpath) == "Messages":
|
||||||
|
messages_dirs.append(Path(dirpath))
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
|
||||||
|
class EmlxReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Apple Mail .emlx file reader with embedded metadata.
|
||||||
|
|
||||||
|
Reads individual .emlx files from Apple Mail's storage format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, include_html: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_html: Whether to include HTML content in the email body (default: False)
|
||||||
|
"""
|
||||||
|
self.include_html = include_html
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load data from the input directory containing .emlx files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing .emlx files
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum amount of messages to read.
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
|
count = 0
|
||||||
|
total_files = 0
|
||||||
|
successful_files = 0
|
||||||
|
failed_files = 0
|
||||||
|
|
||||||
|
print(f"Starting to process directory: {input_dir}")
|
||||||
|
|
||||||
|
# Walk through the directory recursively
|
||||||
|
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||||
|
# Skip hidden directories
|
||||||
|
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||||
|
|
||||||
|
for filename in filenames:
|
||||||
|
# Check if we've reached the max count (skip if max_count == -1)
|
||||||
|
if max_count > 0 and count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if filename.endswith(".emlx"):
|
||||||
|
total_files += 1
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
# Read the .emlx file
|
||||||
|
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# .emlx files have a length prefix followed by the email content
|
||||||
|
# The first line contains the length, followed by the email
|
||||||
|
lines = content.split("\n", 1)
|
||||||
|
if len(lines) >= 2:
|
||||||
|
email_content = lines[1]
|
||||||
|
|
||||||
|
# Parse the email using Python's email module
|
||||||
|
try:
|
||||||
|
msg = email.message_from_string(email_content)
|
||||||
|
|
||||||
|
# Extract email metadata
|
||||||
|
subject = msg.get("Subject", "No Subject")
|
||||||
|
from_addr = msg.get("From", "Unknown")
|
||||||
|
to_addr = msg.get("To", "Unknown")
|
||||||
|
date = msg.get("Date", "Unknown")
|
||||||
|
|
||||||
|
# Extract email body
|
||||||
|
body = ""
|
||||||
|
if msg.is_multipart():
|
||||||
|
for part in msg.walk():
|
||||||
|
if (
|
||||||
|
part.get_content_type() == "text/plain"
|
||||||
|
or part.get_content_type() == "text/html"
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
part.get_content_type() == "text/html"
|
||||||
|
and not self.include_html
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
payload = part.get_payload(decode=True)
|
||||||
|
if payload:
|
||||||
|
body += payload.decode("utf-8", errors="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decoding payload: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
payload = msg.get_payload(decode=True)
|
||||||
|
if payload:
|
||||||
|
body = payload.decode("utf-8", errors="ignore")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error decoding single part payload: {e}")
|
||||||
|
body = ""
|
||||||
|
|
||||||
|
# Only create document if we have some content
|
||||||
|
if body.strip() or subject != "No Subject":
|
||||||
|
# Create document content with metadata embedded in text
|
||||||
|
doc_content = f"""
|
||||||
|
[File]: {filename}
|
||||||
|
[From]: {from_addr}
|
||||||
|
[To]: {to_addr}
|
||||||
|
[Subject]: {subject}
|
||||||
|
[Date]: {date}
|
||||||
|
[EMAIL BODY Start]:
|
||||||
|
{body}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# No separate metadata - everything is in the text
|
||||||
|
doc = Document(text=doc_content, metadata={})
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
successful_files += 1
|
||||||
|
|
||||||
|
# Print first few successful files for debugging
|
||||||
|
if successful_files <= 3:
|
||||||
|
print(
|
||||||
|
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
failed_files += 1
|
||||||
|
if failed_files <= 5: # Only print first few errors
|
||||||
|
print(f"Error parsing email from {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
failed_files += 1
|
||||||
|
if failed_files <= 5: # Only print first few errors
|
||||||
|
print(f"Error reading file {filepath}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("Processing summary:")
|
||||||
|
print(f" Total .emlx files found: {total_files}")
|
||||||
|
print(f" Successfully loaded: {successful_files}")
|
||||||
|
print(f" Failed to load: {failed_files}")
|
||||||
|
print(f" Final documents: {len(docs)}")
|
||||||
|
|
||||||
|
return docs
|
||||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
from fsspec import AbstractFileSystem
|
|
||||||
|
|
||||||
|
from fsspec import AbstractFileSystem
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from llama_index.core.schema import Document
|
from llama_index.core.schema import Document
|
||||||
|
|
||||||
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_MESSAGE_FORMAT: str = (
|
DEFAULT_MESSAGE_FORMAT: str = (
|
||||||
"Date: {_date}\n"
|
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||||
"From: {_from}\n"
|
|
||||||
"To: {_to}\n"
|
|
||||||
"Subject: {_subject}\n"
|
|
||||||
"Content: {_content}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
from bs4 import BeautifulSoup # noqa
|
from bs4 import BeautifulSoup # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.max_count = max_count
|
self.max_count = max_count
|
||||||
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
file: Path,
|
file: Path,
|
||||||
extra_info: Optional[Dict] = None,
|
extra_info: dict | None = None,
|
||||||
fs: Optional[AbstractFileSystem] = None,
|
fs: AbstractFileSystem | None = None,
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
"""Parse file into string."""
|
"""Parse file into string."""
|
||||||
# Import required libraries
|
# Import required libraries
|
||||||
import mailbox
|
import mailbox
|
||||||
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
|
|||||||
)
|
)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
results: List[str] = []
|
results: list[str] = []
|
||||||
# Load file using mailbox
|
# Load file using mailbox
|
||||||
bytes_parser = BytesParser(policy=default).parse
|
bytes_parser = BytesParser(policy=default).parse
|
||||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||||
@@ -134,12 +128,12 @@ class EmlxMboxReader(MboxReader):
|
|||||||
def load_data(
|
def load_data(
|
||||||
self,
|
self,
|
||||||
directory: Path,
|
directory: Path,
|
||||||
extra_info: Optional[Dict] = None,
|
extra_info: dict | None = None,
|
||||||
fs: Optional[AbstractFileSystem] = None,
|
fs: AbstractFileSystem | None = None,
|
||||||
) -> List[Document]:
|
) -> list[Document]:
|
||||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||||
import tempfile
|
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
if fs:
|
if fs:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -156,18 +150,18 @@ class EmlxMboxReader(MboxReader):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# Create a temporary mbox file
|
# Create a temporary mbox file
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||||
temp_mbox_path = temp_mbox.name
|
temp_mbox_path = temp_mbox.name
|
||||||
|
|
||||||
# Convert .emlx files to mbox format
|
# Convert .emlx files to mbox format
|
||||||
for emlx_file in emlx_files:
|
for emlx_file in emlx_files:
|
||||||
try:
|
try:
|
||||||
# Read the .emlx file
|
# Read the .emlx file
|
||||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
# .emlx format: first line is length, rest is email content
|
# .emlx format: first line is length, rest is email content
|
||||||
lines = content.split('\n', 1)
|
lines = content.split("\n", 1)
|
||||||
if len(lines) >= 2:
|
if len(lines) >= 2:
|
||||||
email_content = lines[1] # Skip the length line
|
email_content = lines[1] # Skip the length line
|
||||||
|
|
||||||
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
|
|||||||
# Clean up temporary file
|
# Clean up temporary file
|
||||||
try:
|
try:
|
||||||
os.unlink(temp_mbox_path)
|
os.unlink(temp_mbox_path)
|
||||||
except:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
157
apps/email_rag.py
Normal file
157
apps/email_rag.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""
|
||||||
|
Email RAG example using the unified interface.
|
||||||
|
Supports Apple Mail on macOS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
|
from .email_data.LEANN_email_reader import EmlxReader
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Apple Mail processing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Process all emails by default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Email",
|
||||||
|
description="Process and query Apple Mail emails with LEANN",
|
||||||
|
default_index_name="mail_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add email-specific arguments."""
|
||||||
|
email_group = parser.add_argument_group("Email Parameters")
|
||||||
|
email_group.add_argument(
|
||||||
|
"--mail-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
email_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_mail_directories(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Apple Mail directories."""
|
||||||
|
mail_base = Path.home() / "Library" / "Mail"
|
||||||
|
if not mail_base.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find all Messages directories
|
||||||
|
messages_dirs = []
|
||||||
|
for item in mail_base.rglob("Messages"):
|
||||||
|
if item.is_dir():
|
||||||
|
messages_dirs.append(item)
|
||||||
|
|
||||||
|
return messages_dirs
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load emails and convert to text chunks."""
|
||||||
|
# Determine mail directories
|
||||||
|
if args.mail_path:
|
||||||
|
messages_dirs = [Path(args.mail_path)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Apple Mail directories...")
|
||||||
|
messages_dirs = self._find_mail_directories()
|
||||||
|
|
||||||
|
if not messages_dirs:
|
||||||
|
print("No Apple Mail directories found!")
|
||||||
|
print("Please specify --mail-path manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(messages_dirs)} mail directories")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = EmlxReader(include_html=args.include_html)
|
||||||
|
|
||||||
|
# Process each directory
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, messages_dir in enumerate(messages_dirs):
|
||||||
|
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Count emlx files
|
||||||
|
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||||
|
print(f"Found {len(emlx_files)} email files")
|
||||||
|
|
||||||
|
# Apply max_items limit per directory
|
||||||
|
max_per_dir = -1 # Default to process all
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_dir = remaining
|
||||||
|
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||||
|
|
||||||
|
# Load emails - fix the parameter passing
|
||||||
|
documents = reader.load_data(
|
||||||
|
input_dir=str(messages_dir),
|
||||||
|
max_count=max_per_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} emails from this directory")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {messages_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No emails found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
# Email reader uses chunk_overlap=25 as in original
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||||
|
print(" Windows/Linux support coming soon!\n")
|
||||||
|
|
||||||
|
# Example queries for email RAG
|
||||||
|
print("\n📧 Email RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What did my boss say about deadlines?'")
|
||||||
|
print("- 'Find emails about travel expenses'")
|
||||||
|
print("- 'Show me emails from last month about the project'")
|
||||||
|
print("- 'What food did I order from DoorDash?'")
|
||||||
|
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||||
|
|
||||||
|
rag = EmailRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
from .history import ChromeHistoryReader
|
from .history import ChromeHistoryReader
|
||||||
|
|
||||||
__all__ = ['ChromeHistoryReader']
|
__all__ = ["ChromeHistoryReader"]
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
import sqlite3
|
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
class ChromeHistoryReader(BaseReader):
|
class ChromeHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||||
@@ -17,7 +19,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load Chrome history data from the default Chrome profile location.
|
Load Chrome history data from the default Chrome profile location.
|
||||||
|
|
||||||
@@ -27,13 +29,15 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
max_count (int): Maximum amount of history entries to read.
|
max_count (int): Maximum amount of history entries to read.
|
||||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||||
|
|
||||||
# Default Chrome profile path on macOS
|
# Default Chrome profile path on macOS
|
||||||
if chrome_profile_path is None:
|
if chrome_profile_path is None:
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
@@ -70,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
# Create document content with metadata embedded in text
|
||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
@@ -93,12 +97,17 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error reading Chrome history: {e}")
|
print(f"Error reading Chrome history: {e}")
|
||||||
|
# add you may need to close your browser to make the database file available
|
||||||
|
# also highlight in red
|
||||||
|
print(
|
||||||
|
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||||
|
)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_chrome_profiles() -> List[Path]:
|
def find_chrome_profiles() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all Chrome profile directories.
|
Find all Chrome profile directories.
|
||||||
|
|
||||||
@@ -124,7 +133,9 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
return profile_dirs
|
return profile_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
def export_history_to_file(
|
||||||
|
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export Chrome history to a text file using the same SQL query format.
|
Export Chrome history to a text file using the same SQL query format.
|
||||||
|
|
||||||
@@ -132,7 +143,9 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
output_file: Path to the output file
|
output_file: Path to the output file
|
||||||
max_count: Maximum number of entries to export
|
max_count: Maximum number of entries to export
|
||||||
"""
|
"""
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
chrome_profile_path = os.path.expanduser(
|
||||||
|
"~/Library/Application Support/Google/Chrome/Default"
|
||||||
|
)
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
if not os.path.exists(history_db_path):
|
||||||
@@ -159,10 +172,12 @@ class ChromeHistoryReader(BaseReader):
|
|||||||
cursor.execute(query, (max_count,))
|
cursor.execute(query, (max_count,))
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
f.write(
|
||||||
|
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||||
|
)
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||||
@@ -2,13 +2,14 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from llama_index.core import Document
|
from llama_index.core import Document
|
||||||
from llama_index.core.readers.base import BaseReader
|
from llama_index.core.readers.base import BaseReader
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
class WeChatHistoryReader(BaseReader):
|
class WeChatHistoryReader(BaseReader):
|
||||||
"""
|
"""
|
||||||
@@ -43,10 +44,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||||
if not wechattweak_path.exists():
|
if not wechattweak_path.exists():
|
||||||
print("Downloading WeChatTweak CLI...")
|
print("Downloading WeChatTweak CLI...")
|
||||||
subprocess.run([
|
subprocess.run(
|
||||||
"curl", "-L", "-o", str(wechattweak_path),
|
[
|
||||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
"curl",
|
||||||
], check=True)
|
"-L",
|
||||||
|
"-o",
|
||||||
|
str(wechattweak_path),
|
||||||
|
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
||||||
|
],
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Make executable
|
# Make executable
|
||||||
wechattweak_path.chmod(0o755)
|
wechattweak_path.chmod(0o755)
|
||||||
@@ -73,16 +80,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
def check_api_available(self) -> bool:
|
def check_api_available(self) -> bool:
|
||||||
"""Check if WeChatTweak API is available."""
|
"""Check if WeChatTweak API is available."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||||
], capture_output=True, text=True, timeout=5)
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
return result.returncode == 0 and result.stdout.strip()
|
return result.returncode == 0 and result.stdout.strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_readable_text(self, content: str) -> str:
|
def _extract_readable_text(self, content: str) -> str:
|
||||||
"""
|
"""
|
||||||
Extract readable text from message content, removing XML and system messages.
|
Extract readable text from message content, removing XML and system messages.
|
||||||
@@ -100,14 +107,14 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Extract text from dictionary structure
|
# Extract text from dictionary structure
|
||||||
text_parts = []
|
text_parts = []
|
||||||
if 'title' in content:
|
if "title" in content:
|
||||||
text_parts.append(str(content['title']))
|
text_parts.append(str(content["title"]))
|
||||||
if 'quoted' in content:
|
if "quoted" in content:
|
||||||
text_parts.append(str(content['quoted']))
|
text_parts.append(str(content["quoted"]))
|
||||||
if 'content' in content:
|
if "content" in content:
|
||||||
text_parts.append(str(content['content']))
|
text_parts.append(str(content["content"]))
|
||||||
if 'text' in content:
|
if "text" in content:
|
||||||
text_parts.append(str(content['text']))
|
text_parts.append(str(content["text"]))
|
||||||
|
|
||||||
if text_parts:
|
if text_parts:
|
||||||
return " | ".join(text_parts)
|
return " | ".join(text_parts)
|
||||||
@@ -120,11 +127,11 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n"
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If it's just XML or system message, return empty
|
# If it's just XML or system message, return empty
|
||||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return clean_content.strip()
|
return clean_content.strip()
|
||||||
@@ -145,9 +152,9 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Handle dictionary content
|
# Handle dictionary content
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Check if dict has any readable text fields
|
# Check if dict has any readable text fields
|
||||||
text_fields = ['title', 'quoted', 'content', 'text']
|
text_fields = ["title", "quoted", "content", "text"]
|
||||||
for field in text_fields:
|
for field in text_fields:
|
||||||
if field in content and content[field]:
|
if content.get(field):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -156,42 +163,47 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip image messages (contain XML with img tags)
|
# Skip image messages (contain XML with img tags)
|
||||||
if '<img' in content and 'cdnurl' in content:
|
if "<img" in content and "cdnurl" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip emoji messages (contain emoji XML tags)
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
if '<emoji' in content and 'productid' in content:
|
if "<emoji" in content and "productid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip voice messages
|
# Skip voice messages
|
||||||
if '<voice' in content:
|
if "<voice" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip video messages
|
# Skip video messages
|
||||||
if '<video' in content:
|
if "<video" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip file messages
|
# Skip file messages
|
||||||
if '<appmsg' in content and 'appid' in content:
|
if "<appmsg" in content and "appid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip system messages (like "recalled a message")
|
# Skip system messages (like "recalled a message")
|
||||||
if 'recalled a message' in content:
|
if "recalled a message" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if there's actual readable text (not just XML or system messages)
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If after cleaning we have meaningful text, consider it readable
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
def _concatenate_messages(
|
||||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30,
|
||||||
|
overlap_messages: int = 0,
|
||||||
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Concatenate messages based on length and time rules.
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
@@ -214,12 +226,12 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Extract message info
|
# Extract message info
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -236,16 +248,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if time_diff_minutes > time_window_minutes:
|
if time_diff_minutes > time_window_minutes:
|
||||||
# Time gap too large, start new group
|
# Time gap too large, start new group
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -254,16 +274,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
message_length = len(readable_text)
|
message_length = len(readable_text)
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
# Current group would exceed max length, save it and start new
|
# Current group would exceed max length, save it and start new
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -275,16 +303,18 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
# Add the last group if it exists
|
# Add the last group if it exists
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return concatenated_groups
|
return concatenated_groups
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create concatenated content from a group of messages.
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
@@ -295,16 +325,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted concatenated content
|
Formatted concatenated content
|
||||||
"""
|
"""
|
||||||
messages = message_group['messages']
|
messages = message_group["messages"]
|
||||||
start_time = message_group['start_time']
|
start_time = message_group["start_time"]
|
||||||
end_time = message_group['end_time']
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
# Format timestamps
|
# Format timestamps
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
start_time_str = str(start_time)
|
start_time_str = str(start_time)
|
||||||
else:
|
else:
|
||||||
start_time_str = "Unknown"
|
start_time_str = "Unknown"
|
||||||
@@ -312,8 +342,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
end_time_str = str(end_time)
|
end_time_str = str(end_time)
|
||||||
else:
|
else:
|
||||||
end_time_str = "Unknown"
|
end_time_str = "Unknown"
|
||||||
@@ -321,10 +351,10 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Build concatenated message content
|
# Build concatenated message content
|
||||||
message_parts = []
|
message_parts = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -336,8 +366,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -351,7 +381,7 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
@@ -361,7 +391,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
"""
|
"""
|
||||||
return doc_content, contact_name
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load WeChat chat history data from exported JSON files.
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
@@ -376,13 +406,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||||
include_non_text = load_kwargs.get('include_non_text', False)
|
include_non_text = load_kwargs.get("include_non_text", False)
|
||||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||||
max_length = load_kwargs.get('max_length', 1000)
|
max_length = load_kwargs.get("max_length", 1000)
|
||||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||||
|
|
||||||
# Default WeChat export path
|
# Default WeChat export path
|
||||||
if wechat_export_dir is None:
|
if wechat_export_dir is None:
|
||||||
@@ -403,7 +433,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
with open(json_file, encoding="utf-8") as f:
|
||||||
chat_data = json.load(f)
|
chat_data = json.load(f)
|
||||||
|
|
||||||
# Extract contact name from filename
|
# Extract contact name from filename
|
||||||
@@ -414,7 +444,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
readable_messages = []
|
readable_messages = []
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
try:
|
try:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -430,9 +460,9 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
# Concatenate messages based on rules
|
# Concatenate messages based on rules
|
||||||
message_groups = self._concatenate_messages(
|
message_groups = self._concatenate_messages(
|
||||||
readable_messages,
|
readable_messages,
|
||||||
max_length=-1,
|
max_length=max_length,
|
||||||
time_window_minutes=-1,
|
time_window_minutes=time_window_minutes,
|
||||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
overlap_messages=0, # No overlap between groups
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
# Create documents from concatenated groups
|
||||||
@@ -440,12 +470,19 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
doc_content, contact_name = self._create_concatenated_content(
|
||||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
message_group, contact_name
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content,
|
||||||
|
metadata={"contact_name": contact_name},
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
print(
|
||||||
|
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Original single-message processing
|
# Original single-message processing
|
||||||
@@ -454,12 +491,12 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Extract message information
|
# Extract message information
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
# Handle content that might be dict or string
|
||||||
try:
|
try:
|
||||||
@@ -480,8 +517,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -495,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(
|
||||||
|
text=doc_content, metadata={"contact_name": contact_name}
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -512,7 +551,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_wechat_export_dirs() -> List[Path]:
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all WeChat export directories.
|
Find all WeChat export directories.
|
||||||
|
|
||||||
@@ -523,10 +562,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
|
|
||||||
# Look for common export directory names
|
# Look for common export directory names
|
||||||
possible_dirs = [
|
possible_dirs = [
|
||||||
Path("./wechat_export_test"),
|
|
||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
for export_dir in possible_dirs:
|
||||||
@@ -534,13 +573,20 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
json_files = list(export_dir.glob("*.json"))
|
json_files = list(export_dir.glob("*.json"))
|
||||||
if json_files:
|
if json_files:
|
||||||
export_dirs.append(export_dir)
|
export_dirs.append(export_dir)
|
||||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
return export_dirs
|
return export_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history to a text file.
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
@@ -560,14 +606,14 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
try:
|
try:
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
chat_data = json.load(json_f)
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
@@ -577,10 +623,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
from_user = message.get('fromUser', '')
|
from_user = message.get("fromUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
# Skip non-text messages unless requested
|
||||||
if not include_non_text:
|
if not include_non_text:
|
||||||
@@ -595,8 +641,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -613,7 +659,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
@@ -642,16 +688,21 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
if requirements_file.exists():
|
if requirements_file.exists():
|
||||||
print("Installing wechat-exporter requirements...")
|
print("Installing wechat-exporter requirements...")
|
||||||
subprocess.run([
|
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||||
"uv", "pip", "install", "-r", str(requirements_file)
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Run the export command
|
# Run the export command
|
||||||
print("Running wechat-exporter...")
|
print("Running wechat-exporter...")
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
[
|
||||||
"export-all", str(export_path)
|
sys.executable,
|
||||||
], capture_output=True, text=True, check=True)
|
str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all",
|
||||||
|
str(export_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
print("Export command output:")
|
print("Export command output:")
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
@@ -662,7 +713,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
# Check if export was successful
|
# Check if export was successful
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
json_files = list(export_path.glob("*.json"))
|
json_files = list(export_path.glob("*.json"))
|
||||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
print(
|
||||||
|
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||||
|
)
|
||||||
return export_path
|
return export_path
|
||||||
else:
|
else:
|
||||||
print("Export completed but no JSON files found")
|
print("Export completed but no JSON files found")
|
||||||
@@ -678,7 +731,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find existing WeChat exports or create new ones.
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
@@ -697,7 +750,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
for export_dir_path in possible_export_dirs:
|
||||||
@@ -714,6 +767,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if exported_path:
|
if exported_path:
|
||||||
export_dirs = [exported_path]
|
export_dirs = [exported_path]
|
||||||
else:
|
else:
|
||||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
|
)
|
||||||
|
|
||||||
return export_dirs
|
return export_dirs
|
||||||
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
||||||
|
|
||||||
|
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
||||||
|
|
||||||
|
### What you’ll run
|
||||||
|
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
||||||
|
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
||||||
|
|
||||||
|
## Prerequisites (macOS)
|
||||||
|
|
||||||
|
### 1) Homebrew poppler (for pdf2image)
|
||||||
|
```bash
|
||||||
|
brew install poppler
|
||||||
|
which pdfinfo && pdfinfo -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) Python environment
|
||||||
|
Use uv (recommended) or pip. Python 3.9+.
|
||||||
|
|
||||||
|
Using uv:
|
||||||
|
```bash
|
||||||
|
uv pip install \
|
||||||
|
colpali_engine \
|
||||||
|
pdf2image \
|
||||||
|
pillow \
|
||||||
|
matplotlib qwen_vl_utils \
|
||||||
|
einops \
|
||||||
|
seaborn
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- On first run, models download from Hugging Face. Login/config if needed.
|
||||||
|
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
||||||
|
```bash
|
||||||
|
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run the demos
|
||||||
|
|
||||||
|
### A) Local PDF example
|
||||||
|
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||||
|
# If you don't have the sample PDF locally, download it (ignored by Git)
|
||||||
|
mkdir -p pdfs
|
||||||
|
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
||||||
|
ls pdfs/2004.12832v2.pdf
|
||||||
|
# Ensure output dir exists
|
||||||
|
mkdir -p pages
|
||||||
|
python multi-vector-leann-paper-example.py
|
||||||
|
```
|
||||||
|
Expected:
|
||||||
|
- Page images in `pages/`.
|
||||||
|
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
||||||
|
|
||||||
|
To use your own PDF: edit `pdf_path` near the top of the script.
|
||||||
|
|
||||||
|
### B) Similarity map + answer demo
|
||||||
|
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||||
|
python multi-vector-leann-similarity-map.py
|
||||||
|
```
|
||||||
|
Artifacts (when enabled):
|
||||||
|
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
||||||
|
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
||||||
|
|
||||||
|
Key knobs in the script (top of file):
|
||||||
|
- `QUERY`: your question
|
||||||
|
- `MODEL`: `"colqwen2"` or `"colpali"`
|
||||||
|
- `USE_HF_DATASET`: set `False` to use local pages
|
||||||
|
- `PDF`, `PAGES_DIR`: for local mode
|
||||||
|
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
||||||
|
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
||||||
|
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
||||||
|
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
||||||
|
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
||||||
|
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
||||||
|
- For local PDFs, page images go to `./pages/`.
|
||||||
|
|
||||||
|
|
||||||
|
### Retrieval and Visualization Example
|
||||||
|
|
||||||
|
Example settings in `multi-vector-leann-similarity-map.py`:
|
||||||
|
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
||||||
|
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
||||||
|
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
||||||
|
|
||||||
|
Run:
|
||||||
|
```bash
|
||||||
|
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||||
|
python multi-vector-leann-similarity-map.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Outputs (by default):
|
||||||
|
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
||||||
|
- Similarity map: `./figures/similarity_map_rank1.png`
|
||||||
|
|
||||||
|
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
||||||
|
"):
|
||||||
|

|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
||||||
|
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
||||||
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 166 KiB |
@@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class LeannMultiVector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index_path: str,
|
||||||
|
dim: int = 128,
|
||||||
|
distance_metric: str = "mips",
|
||||||
|
m: int = 16,
|
||||||
|
ef_construction: int = 500,
|
||||||
|
is_compact: bool = False,
|
||||||
|
is_recompute: bool = False,
|
||||||
|
embedding_model_name: str = "colvision",
|
||||||
|
) -> None:
|
||||||
|
self.index_path = index_path
|
||||||
|
self.dim = dim
|
||||||
|
self.embedding_model_name = embedding_model_name
|
||||||
|
self._pending_items: list[dict] = []
|
||||||
|
self._backend_kwargs = {
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"M": m,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
"is_compact": is_compact,
|
||||||
|
"is_recompute": is_recompute,
|
||||||
|
}
|
||||||
|
self._labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
def _meta_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": self.embedding_model_name,
|
||||||
|
"embedding_mode": "custom",
|
||||||
|
"dimensions": self.dim,
|
||||||
|
"backend_kwargs": self._backend_kwargs,
|
||||||
|
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||||
|
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||||
|
and self._backend_kwargs.get("is_recompute", True),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_collection(self) -> None:
|
||||||
|
path = Path(self.index_path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def insert(self, data: dict) -> None:
|
||||||
|
self._pending_items.append(
|
||||||
|
{
|
||||||
|
"doc_id": int(data["doc_id"]),
|
||||||
|
"filepath": data.get("filepath", ""),
|
||||||
|
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _labels_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||||
|
|
||||||
|
def _meta_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
|
def create_index(self) -> None:
|
||||||
|
if not self._pending_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings: list[np.ndarray] = []
|
||||||
|
labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
for item in self._pending_items:
|
||||||
|
doc_id = int(item["doc_id"])
|
||||||
|
filepath = item.get("filepath", "")
|
||||||
|
colbert_vecs = item["colbert_vecs"]
|
||||||
|
for seq_id, vec in enumerate(colbert_vecs):
|
||||||
|
vec_np = np.asarray(vec, dtype=np.float32)
|
||||||
|
embeddings.append(vec_np)
|
||||||
|
labels_meta.append(
|
||||||
|
{
|
||||||
|
"id": f"{doc_id}:{seq_id}",
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"seq_id": int(seq_id),
|
||||||
|
"filepath": filepath,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||||
|
# print shape of embeddings_np
|
||||||
|
print(embeddings_np.shape)
|
||||||
|
|
||||||
|
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||||
|
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||||
|
builder.build(embeddings_np, ids, self.index_path)
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(self._meta_dict(), f, indent=2)
|
||||||
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
|
if self._labels_meta:
|
||||||
|
return
|
||||||
|
labels_path = self._labels_path()
|
||||||
|
if labels_path.exists():
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
|
||||||
|
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||||
|
raw = searcher.search(
|
||||||
|
data,
|
||||||
|
first_stage_k,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
batch_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = raw.get("labels")
|
||||||
|
distances = raw.get("distances")
|
||||||
|
if labels is None or distances is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
doc_scores: dict[int, float] = {}
|
||||||
|
B = len(labels)
|
||||||
|
for b in range(B):
|
||||||
|
per_doc_best: dict[int, float] = {}
|
||||||
|
for k, sid in enumerate(labels[b]):
|
||||||
|
try:
|
||||||
|
idx = int(sid)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if 0 <= idx < len(self._labels_meta):
|
||||||
|
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
score = float(distances[b][k])
|
||||||
|
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||||
|
per_doc_best[doc_id] = score
|
||||||
|
for doc_id, best_score in per_doc_best.items():
|
||||||
|
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||||
|
|
||||||
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
# pip install pdf2image
|
||||||
|
# pip install pymilvus
|
||||||
|
# pip install colpali_engine
|
||||||
|
# pip install tqdm
|
||||||
|
# pip install pillow
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Ensure local leann packages are importable before importing them
|
||||||
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||||
|
_device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(_device_str)
|
||||||
|
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||||
|
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=_dtype,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"How to end-to-end retrieval with ColBert",
|
||||||
|
"Where is ColBERT performance Table, including text representation results?",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
qs: list[torch.Tensor] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
print(qs[0].shape)
|
||||||
|
# %%
|
||||||
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
ds: list[torch.Tensor] = []
|
||||||
|
for batch_doc in tqdm(dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
|
||||||
|
print(ds[0].shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Build HNSW index via LeannRetriever primitives and run search
|
||||||
|
index_path = "./indexes/colpali.leann"
|
||||||
|
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||||
|
retriever.create_collection()
|
||||||
|
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": ds[i].float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
for query in qs:
|
||||||
|
query_np = query.float().numpy()
|
||||||
|
result = retriever.search(query_np, topk=1)
|
||||||
|
print(filepaths[result[0][1]])
|
||||||
@@ -0,0 +1,477 @@
|
|||||||
|
## Jupyter-style notebook script
|
||||||
|
# %%
|
||||||
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Config
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||||
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
|
USE_HF_DATASET: bool = True
|
||||||
|
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||||
|
DATASET_SPLIT: str = "train"
|
||||||
|
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||||
|
|
||||||
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
|
PAGES_DIR: str = "./pages"
|
||||||
|
|
||||||
|
# Index + retrieval settings
|
||||||
|
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||||
|
TOPK: int = 1
|
||||||
|
FIRST_STAGE_K: int = 500
|
||||||
|
REBUILD_INDEX: bool = False
|
||||||
|
|
||||||
|
# Artifacts
|
||||||
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
|
SIMILARITY_MAP: bool = True
|
||||||
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
|
ANSWER: bool = True
|
||||||
|
MAX_NEW_TOKENS: int = 128
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Helpers
|
||||||
|
def _natural_sort_key(name: str) -> int:
|
||||||
|
m = re.search(r"\d+", name)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||||
|
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||||
|
filenames = sorted(filenames, key=_natural_sort_key)
|
||||||
|
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||||
|
images = [Image.open(p) for p in filepaths]
|
||||||
|
return filepaths, images
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
|
if not pdf_path:
|
||||||
|
return
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||||
|
) from e
|
||||||
|
images = convert_from_path(pdf_path, dpi=dpi)
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||||
|
|
||||||
|
|
||||||
|
def _select_device_and_dtype():
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import get_torch_device
|
||||||
|
|
||||||
|
device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(device_str)
|
||||||
|
# Stable dtype selection to avoid NaNs:
|
||||||
|
# - CUDA: prefer bfloat16 if supported, else float16
|
||||||
|
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||||
|
# - CPU: float32
|
||||||
|
if device_str == "cuda":
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif device_str == "mps":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
return device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _load_colvision(model_choice: str):
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
if model_choice == "colqwen2":
|
||||||
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
|
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2"
|
||||||
|
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||||
|
else "eager"
|
||||||
|
)
|
||||||
|
model = ColQwen2.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
|
else:
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Ensure deterministic eval and autocast for stability
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_vecs: list[Any] = []
|
||||||
|
for batch_doc in tqdm(dataloader, desc="Embedding images"):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
else:
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
return doc_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_vecs: list[Any] = []
|
||||||
|
for batch_query in tqdm(dataloader, desc="Embedding queries"):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
else:
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
return q_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||||
|
dim = int(doc_vecs[0].shape[-1])
|
||||||
|
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
retriever.create_collection()
|
||||||
|
for i, vec in enumerate(doc_vecs):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": vec.float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||||
|
index_base = Path(index_path)
|
||||||
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_similarity_map(
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
image: Image.Image,
|
||||||
|
query: str,
|
||||||
|
token_idx: Optional[int] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> tuple[int, float]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.interpretability import (
|
||||||
|
get_similarity_maps_from_embeddings,
|
||||||
|
plot_similarity_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = processor.process_images([image]).to(model.device)
|
||||||
|
batch_queries = processor.process_queries([query]).to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_embeddings = model.forward(**batch_images)
|
||||||
|
query_embeddings = model.forward(**batch_queries)
|
||||||
|
|
||||||
|
n_patches = processor.get_n_patches(
|
||||||
|
image_size=image.size,
|
||||||
|
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||||
|
)
|
||||||
|
image_mask = processor.get_image_mask(batch_images)
|
||||||
|
|
||||||
|
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_patches=n_patches,
|
||||||
|
image_mask=image_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_maps = batched_similarity_maps[0]
|
||||||
|
|
||||||
|
# Determine token index if not provided: choose the token with highest max score
|
||||||
|
if token_idx is None:
|
||||||
|
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||||
|
token_idx = int(per_token_max.argmax().item())
|
||||||
|
|
||||||
|
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plot_similarity_map(
|
||||||
|
image=image,
|
||||||
|
similarity_map=similarity_maps[token_idx],
|
||||||
|
figsize=(14, 14),
|
||||||
|
show_colorbar=False,
|
||||||
|
)
|
||||||
|
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
plt.savefig(output_path, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return token_idx, float(max_sim_score)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVL:
|
||||||
|
def __init__(self, device: str):
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
min_pixels = 256 * 28 * 28
|
||||||
|
max_pixels = 1280 * 28 * 28
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
|
||||||
|
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for img in images:
|
||||||
|
buffer = BytesIO()
|
||||||
|
img.save(buffer, format="jpeg")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||||
|
content.append({"type": "text", "text": query})
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.model.device)
|
||||||
|
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
return self.processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Step 1: Prepare data
|
||||||
|
if USE_HF_DATASET:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||||
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
|
filepaths: list[str] = []
|
||||||
|
images: list[Image.Image] = []
|
||||||
|
for i in tqdm(range(N), desc="Loading dataset", total=N ):
|
||||||
|
p = dataset[i]
|
||||||
|
# Compose a descriptive identifier for printing later
|
||||||
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
print(identifier)
|
||||||
|
filepaths.append(identifier)
|
||||||
|
images.append(p["page_image"]) # PIL Image
|
||||||
|
else:
|
||||||
|
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||||
|
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 2: Load model and processor
|
||||||
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
|
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 3: Build or load index
|
||||||
|
retriever: Optional[LeannMultiVector] = None
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
try:
|
||||||
|
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||||
|
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||||
|
except Exception:
|
||||||
|
retriever = None
|
||||||
|
|
||||||
|
if retriever is None:
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 4: Embed query and search
|
||||||
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
|
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
if not results:
|
||||||
|
print("No results found.")
|
||||||
|
else:
|
||||||
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
|
top_images: list[Image.Image] = []
|
||||||
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
|
path = filepaths[doc_id]
|
||||||
|
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||||
|
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||||
|
top_images.append(images[doc_id])
|
||||||
|
|
||||||
|
if SAVE_TOP_IMAGE:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
base = _Path(SAVE_TOP_IMAGE)
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if base.suffix:
|
||||||
|
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||||
|
else:
|
||||||
|
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
|
img.save(str(out_path))
|
||||||
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
|
## TODO stange results of second page of DeepSeek-V2 rather than the first page
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 5: Similarity maps for top-K results
|
||||||
|
if results and SIMILARITY_MAP:
|
||||||
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if output_base:
|
||||||
|
if output_base.suffix:
|
||||||
|
out_dir = output_base.parent
|
||||||
|
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||||
|
out_path = str(out_dir / out_name)
|
||||||
|
else:
|
||||||
|
out_dir = output_base
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||||
|
else:
|
||||||
|
out_path = None
|
||||||
|
chosen_idx, max_sim = _generate_similarity_map(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
image=img,
|
||||||
|
query=QUERY,
|
||||||
|
token_idx=token_idx,
|
||||||
|
output_path=out_path,
|
||||||
|
)
|
||||||
|
if out_path:
|
||||||
|
print(
|
||||||
|
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 6: Optional answer generation
|
||||||
|
if results and ANSWER:
|
||||||
|
qwen = QwenVL(device=device_str)
|
||||||
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
print("\nAnswer:")
|
||||||
|
print(response)
|
||||||
189
apps/wechat_rag.py
Normal file
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
WeChat History RAG example using the unified interface.
|
||||||
|
Supports WeChat chat history export and search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
from .history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatRAG(BaseRAGExample):
|
||||||
|
"""RAG example for WeChat chat history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Match original default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="WeChat History",
|
||||||
|
description="Process and query WeChat chat history with LEANN",
|
||||||
|
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add WeChat-specific arguments."""
|
||||||
|
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_export",
|
||||||
|
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||||
|
"""Export WeChat data using wechattweak-cli."""
|
||||||
|
print("Exporting WeChat data...")
|
||||||
|
|
||||||
|
# Check if WeChat is running
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
print("WeChat is not running. Please start WeChat first.")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass # pgrep might not be available on all systems
|
||||||
|
|
||||||
|
# Create export directory
|
||||||
|
export_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Run export command
|
||||||
|
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Running: {' '.join(cmd)}")
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
print("WeChat data exported successfully!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"Export failed: {result.stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("\nError: wechattweak-cli not found!")
|
||||||
|
print("Please install it first:")
|
||||||
|
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Export error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load WeChat history and convert to text chunks."""
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||||
|
# Try to find any existing exports in common locations
|
||||||
|
export_dirs = reader.find_wechat_export_dirs()
|
||||||
|
if not export_dirs:
|
||||||
|
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Load documents from all found export directories
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per export
|
||||||
|
max_per_export = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_export = remaining
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_per_export,
|
||||||
|
concatenate_messages=True, # Enable message concatenation for better context
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks with contact information
|
||||||
|
all_texts = []
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
text_splitter = SentenceSplitter(
|
||||||
|
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
# Add contact information to each chunk
|
||||||
|
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||||
|
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||||
|
print(" You can still query existing exports on other platforms\n")
|
||||||
|
|
||||||
|
# Example queries for WeChat RAG
|
||||||
|
print("\n💬 WeChat History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'Show me conversations about travel plans'")
|
||||||
|
print("- 'Find group chats about weekend activities'")
|
||||||
|
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||||
|
print("- 'What did we discuss about the project last month?'")
|
||||||
|
print("\nNote: WeChat must be running for export to work\n")
|
||||||
|
|
||||||
|
rag = WeChatRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
@@ -1,9 +1,24 @@
|
|||||||
# 🧪 Leann Sanity Checks
|
# 🧪 LEANN Benchmarks & Testing
|
||||||
|
|
||||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||||
|
|
||||||
## 📁 Test Files
|
## 📁 Test Files
|
||||||
|
|
||||||
|
### `diskann_vs_hnsw_speed_comparison.py`
|
||||||
|
Performance comparison between DiskANN and HNSW backends:
|
||||||
|
- ✅ **Search latency** comparison with both backends using recompute
|
||||||
|
- ✅ **Index size** and **build time** measurements
|
||||||
|
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||||
|
- ✅ **Configurable dataset sizes** for different scales
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quick comparison with 500 docs, 10 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||||
|
|
||||||
|
# Large-scale comparison with 2000 docs, 20 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||||
|
```
|
||||||
|
|
||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
@@ -1,10 +1,11 @@
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from mlx_lm import load
|
from mlx_lm import load
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||||
@@ -18,12 +19,14 @@ DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_
|
|||||||
|
|
||||||
# --- Benchmark Functions ---b
|
# --- Benchmark Functions ---b
|
||||||
|
|
||||||
|
|
||||||
def benchmark_torch(model, sentences):
|
def benchmark_torch(model, sentences):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
model.encode(sentences, convert_to_numpy=True)
|
model.encode(sentences, convert_to_numpy=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
return (end_time - start_time) * 1000 # Return time in ms
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
def benchmark_mlx(model, tokenizer, sentences):
|
def benchmark_mlx(model, tokenizer, sentences):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -63,6 +66,7 @@ def benchmark_mlx(model, tokenizer, sentences):
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
return (end_time - start_time) * 1000 # Return time in ms
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
# --- Main Execution ---
|
# --- Main Execution ---
|
||||||
def main():
|
def main():
|
||||||
print("--- Initializing Models ---")
|
print("--- Initializing Models ---")
|
||||||
@@ -98,7 +102,9 @@ def main():
|
|||||||
results_torch.append(np.mean(torch_times))
|
results_torch.append(np.mean(torch_times))
|
||||||
|
|
||||||
# Benchmark MLX
|
# Benchmark MLX
|
||||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
mlx_times = [
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||||
|
]
|
||||||
results_mlx.append(np.mean(mlx_times))
|
results_mlx.append(np.mean(mlx_times))
|
||||||
|
|
||||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||||
@@ -109,10 +115,16 @@ def main():
|
|||||||
# --- Plotting ---
|
# --- Plotting ---
|
||||||
print("\n--- Generating Plot ---")
|
print("\n--- Generating Plot ---")
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
plt.plot(
|
||||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
BATCH_SIZES,
|
||||||
|
results_torch,
|
||||||
|
marker="o",
|
||||||
|
linestyle="-",
|
||||||
|
label=f"PyTorch ({device})",
|
||||||
|
)
|
||||||
|
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||||
|
|
||||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||||
plt.xlabel("Batch Size")
|
plt.xlabel("Batch Size")
|
||||||
plt.ylabel("Average Time per Batch (ms)")
|
plt.ylabel("Average Time per Batch (ms)")
|
||||||
plt.xticks(BATCH_SIZES)
|
plt.xticks(BATCH_SIZES)
|
||||||
@@ -124,5 +136,6 @@ def main():
|
|||||||
plt.savefig(output_filename)
|
plt.savefig(output_filename)
|
||||||
print(f"Plot saved to {output_filename}")
|
print(f"Plot saved to {output_filename}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_exists(index_path: str) -> bool:
|
||||||
|
p = Path(index_path)
|
||||||
|
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||||
|
# if _meta_exists(index_path):
|
||||||
|
# return
|
||||||
|
kwargs = {}
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
kwargs["is_compact"] = is_recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||||
|
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
num_threads=4,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
for i in range(num_docs):
|
||||||
|
builder.add_text(
|
||||||
|
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||||
|
)
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_group(
|
||||||
|
index_path: str,
|
||||||
|
recompute: bool,
|
||||||
|
query: str,
|
||||||
|
repeats: int,
|
||||||
|
complexity: int = 32,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> float:
|
||||||
|
# Independent searcher per group; fixed port when recompute
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
|
||||||
|
# Warm-up once
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _once() -> float:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
return time.time() - t0
|
||||||
|
|
||||||
|
if repeats <= 1:
|
||||||
|
t = _once()
|
||||||
|
else:
|
||||||
|
vals = [_once() for _ in range(repeats)]
|
||||||
|
vals.sort()
|
||||||
|
t = vals[len(vals) // 2]
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-docs", type=int, default=5000)
|
||||||
|
parser.add_argument("--repeats", type=int, default=3)
|
||||||
|
parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# ---------- Build HNSW variants ----------
|
||||||
|
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||||
|
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||||
|
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||||
|
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Build DiskANN variants ----------
|
||||||
|
diskann_r = str(base / "diskann_r.leann")
|
||||||
|
diskann_nr = str(base / "diskann_nr.leann")
|
||||||
|
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||||
|
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Helpers ----------
|
||||||
|
def _size_for(prefix: str) -> int:
|
||||||
|
p = Path(prefix)
|
||||||
|
base_dir = p.parent
|
||||||
|
stem = p.stem
|
||||||
|
total = 0
|
||||||
|
for f in base_dir.iterdir():
|
||||||
|
if f.is_file() and f.name.startswith(stem):
|
||||||
|
total += f.stat().st_size
|
||||||
|
return total
|
||||||
|
|
||||||
|
# ---------- HNSW benchmark ----------
|
||||||
|
t_hnsw_r = _bench_group(
|
||||||
|
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_hnsw_nr = _bench_group(
|
||||||
|
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
size_hnsw_r = _size_for(hnsw_r)
|
||||||
|
size_hnsw_nr = _size_for(hnsw_nr)
|
||||||
|
|
||||||
|
print("Benchmark results (HNSW):")
|
||||||
|
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(
|
||||||
|
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||||
|
)
|
||||||
|
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||||
|
|
||||||
|
# ---------- DiskANN benchmark ----------
|
||||||
|
t_diskann_r = _bench_group(
|
||||||
|
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_diskann_nr = _bench_group(
|
||||||
|
diskann_nr,
|
||||||
|
False,
|
||||||
|
"DiskANN NR test doc 123",
|
||||||
|
repeats=args.repeats,
|
||||||
|
complexity=args.complexity,
|
||||||
|
)
|
||||||
|
size_diskann_r = _size_for(diskann_r)
|
||||||
|
size_diskann_nr = _size_for(diskann_nr)
|
||||||
|
|
||||||
|
print("\nBenchmark results (DiskANN):")
|
||||||
|
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||||
|
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
BM25 vs DiskANN Baselines
|
||||||
|
|
||||||
|
```bash
|
||||||
|
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||||
|
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||||
|
```
|
||||||
|
|
||||||
|
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||||
|
- Machine-specific; results measured locally with the current repo.
|
||||||
|
|
||||||
|
DiskANN (NQ queries, search-only)
|
||||||
|
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||||
|
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||||
|
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||||
|
|
||||||
|
BM25
|
||||||
|
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||||
|
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||||
|
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||||
|
|
||||||
|
Notes
|
||||||
|
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||||
|
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||||
|
After Width: | Height: | Size: 1.3 KiB |
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "pyserini"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
# sudo pacman -S jdk21-openjdk
|
||||||
|
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||||
|
# sudo archlinux-java status
|
||||||
|
# sudo archlinux-java set java-21-openjdk
|
||||||
|
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||||
|
# fish_add_path --global $JAVA_HOME/bin
|
||||||
|
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||||
|
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||||
|
_, ext = os.path.splitext(path)
|
||||||
|
if ext.lower() in {".jsonl", ".json"}:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Not strict JSONL? treat the whole line as the query
|
||||||
|
queries.append(line)
|
||||||
|
continue
|
||||||
|
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||||
|
if q:
|
||||||
|
queries.append(str(q))
|
||||||
|
else:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
s = line.strip()
|
||||||
|
if s:
|
||||||
|
queries.append(s)
|
||||||
|
|
||||||
|
if limit is not None and limit > 0:
|
||||||
|
queries = queries[:limit]
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def percentile(values: list[float], p: float) -> float:
|
||||||
|
if not values:
|
||||||
|
return 0.0
|
||||||
|
s = sorted(values)
|
||||||
|
k = (len(s) - 1) * (p / 100.0)
|
||||||
|
f = int(k)
|
||||||
|
c = min(f + 1, len(s) - 1)
|
||||||
|
if f == c:
|
||||||
|
return s[f]
|
||||||
|
return s[f] + (s[c] - s[f]) * (k - f)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--bm25-index",
|
||||||
|
default="benchmarks/data/indices/bm25_index",
|
||||||
|
help="Path to Pyserini Lucene index directory",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--queries",
|
||||||
|
default="benchmarks/data/queries/nq_open.jsonl",
|
||||||
|
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||||
|
)
|
||||||
|
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||||
|
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||||
|
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||||
|
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||||
|
)
|
||||||
|
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyserini.search.lucene import LuceneSearcher
|
||||||
|
except Exception:
|
||||||
|
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not os.path.isdir(args.bm25_index):
|
||||||
|
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
queries = load_queries(args.queries, args.limit)
|
||||||
|
if not queries:
|
||||||
|
print("No queries loaded.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||||
|
print(f"Opening BM25 index: {args.bm25_index}")
|
||||||
|
searcher = LuceneSearcher(args.bm25_index)
|
||||||
|
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||||
|
try:
|
||||||
|
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
total_searches = 0
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for i in range(min(args.warmup, len(queries))):
|
||||||
|
_ = searcher.search(queries[i], k=args.k)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
for i, q in enumerate(queries):
|
||||||
|
t1 = time.time()
|
||||||
|
hits = searcher.search(q, k=args.k)
|
||||||
|
t2 = time.time()
|
||||||
|
latencies.append(t2 - t1)
|
||||||
|
total_searches += 1
|
||||||
|
|
||||||
|
if args.fetch_docs:
|
||||||
|
# Optional doc fetch to include I/O time
|
||||||
|
for h in hits:
|
||||||
|
try:
|
||||||
|
_ = searcher.doc(h.docid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
total_time = t1 - t0
|
||||||
|
|
||||||
|
if latencies:
|
||||||
|
avg = mean(latencies)
|
||||||
|
p50 = percentile(latencies, 50)
|
||||||
|
p90 = percentile(latencies, 90)
|
||||||
|
p95 = percentile(latencies, 95)
|
||||||
|
p99 = percentile(latencies, 99)
|
||||||
|
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||||
|
else:
|
||||||
|
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||||
|
|
||||||
|
print("BM25 Latency Report")
|
||||||
|
print(f" queries: {total_searches}")
|
||||||
|
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||||
|
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||||
|
|
||||||
|
if args.report:
|
||||||
|
payload = {
|
||||||
|
"queries": total_searches,
|
||||||
|
"k": args.k,
|
||||||
|
"k1": args.k1,
|
||||||
|
"b": args.b,
|
||||||
|
"avg_s": avg,
|
||||||
|
"p50_s": p50,
|
||||||
|
"p90_s": p90,
|
||||||
|
"p95_s": p95,
|
||||||
|
"p99_s": p99,
|
||||||
|
"total_time_s": total_time,
|
||||||
|
"qps": qps,
|
||||||
|
"index_dir": os.path.abspath(args.bm25_index),
|
||||||
|
"fetch_docs": bool(args.fetch_docs),
|
||||||
|
}
|
||||||
|
with open(args.report, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(payload, f, indent=2)
|
||||||
|
print(f"Saved report to {args.report}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "leann-backend-diskann"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||||
|
out: list[str] = []
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
obj = json.loads(line)
|
||||||
|
out.append(obj["query"])
|
||||||
|
if limit and len(out) >= limit:
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(
|
||||||
|
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||||
|
help="Directory containing DiskANN files",
|
||||||
|
)
|
||||||
|
ap.add_argument("--index-prefix", default="ann")
|
||||||
|
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||||
|
ap.add_argument("--num-queries", type=int, default=200)
|
||||||
|
ap.add_argument("--top-k", type=int, default=10)
|
||||||
|
ap.add_argument("--complexity", type=int, default=62)
|
||||||
|
ap.add_argument("--threads", type=int, default=1)
|
||||||
|
ap.add_argument("--beam-width", type=int, default=1)
|
||||||
|
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||||
|
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
index_dir = Path(args.index_dir).resolve()
|
||||||
|
if not index_dir.is_dir():
|
||||||
|
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||||
|
|
||||||
|
qpath = Path(args.queries_file).resolve()
|
||||||
|
if not qpath.exists():
|
||||||
|
raise SystemExit(f"Queries file not found: {qpath}")
|
||||||
|
|
||||||
|
queries = load_queries(qpath, args.num_queries)
|
||||||
|
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||||
|
|
||||||
|
# Compute embeddings once (exclude from timing)
|
||||||
|
from leann.api import compute_embeddings as _compute
|
||||||
|
|
||||||
|
embs = _compute(
|
||||||
|
queries,
|
||||||
|
model_name="facebook/contriever-msmarco",
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
if embs.ndim != 2:
|
||||||
|
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||||
|
|
||||||
|
# Build searcher
|
||||||
|
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||||
|
|
||||||
|
index_prefix_path = str(index_dir / args.index_prefix)
|
||||||
|
searcher = _DiskannSearcher(
|
||||||
|
index_prefix_path,
|
||||||
|
num_threads=int(args.threads),
|
||||||
|
cache_mechanism=int(args.cache_mechanism),
|
||||||
|
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup (not timed)
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[0:1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timed loop
|
||||||
|
times: list[float] = []
|
||||||
|
for i in range(embs.shape[0]):
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[i : i + 1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
times.append(time.time() - t0)
|
||||||
|
|
||||||
|
times_sorted = sorted(times)
|
||||||
|
avg = float(sum(times) / len(times))
|
||||||
|
p50 = times_sorted[len(times) // 2]
|
||||||
|
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||||
|
|
||||||
|
print("\nDiskANN (NQ, search-only) Report")
|
||||||
|
print(f" queries: {len(times)}")
|
||||||
|
print(
|
||||||
|
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||||
|
)
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||||
|
print(f" QPS: {1.0 / avg:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,14 +3,15 @@
|
|||||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[sys.executable, "examples/faiss_only.py"],
|
[sys.executable, "benchmarks/faiss_only.py"],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300,
|
timeout=300,
|
||||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "Peak Memory:" in line:
|
if "Peak Memory:" in line:
|
||||||
peak_memory = float(
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"peak_memory": peak_memory}
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
@@ -111,13 +110,12 @@ def test_leann_hnsw():
|
|||||||
|
|
||||||
tracker.checkpoint("After imports")
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
# Load and parse documents
|
# Load and parse documents
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"examples/data",
|
"data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
@@ -135,6 +133,7 @@ def test_leann_hnsw():
|
|||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
@@ -201,11 +200,9 @@ def test_leann_hnsw():
|
|||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
tracker.checkpoint("After searcher loading")
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Running search queries...")
|
print("Running search queries...")
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
@@ -303,21 +300,15 @@ def main():
|
|||||||
|
|
||||||
print("\nLEANN vs Faiss Performance:")
|
print("\nLEANN vs Faiss Performance:")
|
||||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
print(
|
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Storage comparison
|
# Storage comparison
|
||||||
if leann_storage_size > faiss_storage_size:
|
if leann_storage_size > faiss_storage_size:
|
||||||
storage_ratio = leann_storage_size / faiss_storage_size
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
|
||||||
)
|
|
||||||
elif faiss_storage_size > leann_storage_size:
|
elif faiss_storage_size > leann_storage_size:
|
||||||
storage_ratio = faiss_storage_size / leann_storage_size
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(" Storage Size: similar")
|
print(" Storage Size: similar")
|
||||||
else:
|
else:
|
||||||
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DiskANN vs HNSW Search Performance Comparison
|
||||||
|
|
||||||
|
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||||
|
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||||
|
- HNSW: With recompute enabled (is_recompute=True)
|
||||||
|
- Tests performance across different dataset sizes
|
||||||
|
- Measures search latency, recall, and index size
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import multiprocessing as mp
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||||
|
try:
|
||||||
|
mp.set_start_method("fork", force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_texts(n_docs: int) -> list[str]:
|
||||||
|
"""Create synthetic test documents for benchmarking."""
|
||||||
|
np.random.seed(42)
|
||||||
|
topics = [
|
||||||
|
"machine learning and artificial intelligence",
|
||||||
|
"natural language processing and text analysis",
|
||||||
|
"computer vision and image recognition",
|
||||||
|
"data science and statistical analysis",
|
||||||
|
"deep learning and neural networks",
|
||||||
|
"information retrieval and search engines",
|
||||||
|
"database systems and data management",
|
||||||
|
"software engineering and programming",
|
||||||
|
"cybersecurity and network protection",
|
||||||
|
"cloud computing and distributed systems",
|
||||||
|
]
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
for i in range(n_docs):
|
||||||
|
topic = topics[i % len(topics)]
|
||||||
|
variation = np.random.randint(1, 100)
|
||||||
|
text = (
|
||||||
|
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||||
|
f"Additional information about {topic} with details and examples. "
|
||||||
|
f"Technical discussion of {topic} including implementation aspects."
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_backend(
|
||||||
|
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark a specific backend with the given configuration."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Measure index size
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||||
|
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||||
|
size_mb = total_size / (1024 * 1024)
|
||||||
|
|
||||||
|
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||||
|
|
||||||
|
# Search benchmark
|
||||||
|
print("🔍 Running search benchmark...")
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
all_results.append(results)
|
||||||
|
|
||||||
|
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||||
|
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||||
|
|
||||||
|
# Check for valid scores (detect -inf issues)
|
||||||
|
all_scores = [
|
||||||
|
result.score
|
||||||
|
for results in all_results
|
||||||
|
for result in results
|
||||||
|
if result.score is not None
|
||||||
|
]
|
||||||
|
valid_scores = [
|
||||||
|
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||||
|
]
|
||||||
|
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||||
|
|
||||||
|
# Clean up (ensure embedding server shutdown and object GC)
|
||||||
|
try:
|
||||||
|
if hasattr(searcher, "cleanup"):
|
||||||
|
searcher.cleanup()
|
||||||
|
del searcher
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"build_time": build_time,
|
||||||
|
"avg_search_time_ms": avg_search_time,
|
||||||
|
"index_size_mb": size_mb,
|
||||||
|
"score_validity_rate": score_validity_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||||
|
"""Run performance comparison between DiskANN and HNSW."""
|
||||||
|
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||||
|
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
texts = create_test_texts(n_docs)
|
||||||
|
test_queries = [
|
||||||
|
"machine learning algorithms",
|
||||||
|
"natural language processing",
|
||||||
|
"computer vision techniques",
|
||||||
|
"data analysis methods",
|
||||||
|
"neural network architectures",
|
||||||
|
"database query optimization",
|
||||||
|
"software development practices",
|
||||||
|
"security vulnerabilities",
|
||||||
|
"cloud infrastructure",
|
||||||
|
"distributed computing",
|
||||||
|
][:n_queries]
|
||||||
|
|
||||||
|
# HNSW benchmark
|
||||||
|
hnsw_results = benchmark_backend(
|
||||||
|
backend_name="hnsw",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable recompute for fair comparison
|
||||||
|
"M": 16,
|
||||||
|
"efConstruction": 200,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiskANN benchmark
|
||||||
|
diskann_results = benchmark_backend(
|
||||||
|
backend_name="diskann",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable graph partitioning
|
||||||
|
"num_neighbors": 32,
|
||||||
|
"search_list_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
print("\n📈 Performance Comparison Results")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||||
|
print(f"{'-' * 60}")
|
||||||
|
|
||||||
|
# Build time comparison
|
||||||
|
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||||
|
print(
|
||||||
|
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search time comparison
|
||||||
|
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||||
|
print(
|
||||||
|
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index size comparison
|
||||||
|
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||||
|
print(
|
||||||
|
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score validity
|
||||||
|
print(
|
||||||
|
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print("\n🎯 Summary:")
|
||||||
|
if search_speedup > 1:
|
||||||
|
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||||
|
else:
|
||||||
|
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||||
|
|
||||||
|
if size_ratio > 1:
|
||||||
|
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||||
|
else:
|
||||||
|
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle help request
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||||
|
print()
|
||||||
|
print("Arguments:")
|
||||||
|
print(" n_docs Number of documents to index (default: 500)")
|
||||||
|
print(" n_queries Number of test queries to run (default: 10)")
|
||||||
|
print()
|
||||||
|
print("Examples:")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||||
|
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||||
|
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||||
|
print()
|
||||||
|
|
||||||
|
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Benchmark interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Benchmark failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||||
|
try:
|
||||||
|
gc.collect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
# Flush stdio to ensure message is visible before hard-exit
|
||||||
|
try:
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
_sys.stdout.flush()
|
||||||
|
_sys.stderr.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||||
|
import os as _os
|
||||||
|
|
||||||
|
_os._exit(0)
|
||||||
141
benchmarks/enron_emails/README.md
Normal file
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Enron Emails Benchmark
|
||||||
|
|
||||||
|
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||||
|
|
||||||
|
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||||
|
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||||
|
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||||
|
|
||||||
|
## Layout
|
||||||
|
|
||||||
|
benchmarks/enron_emails/
|
||||||
|
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||||
|
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||||
|
- data/: Generated passages, queries, embeddings-related files
|
||||||
|
- baseline/: FAISS Flat baseline files
|
||||||
|
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
1) Prepare the data and index
|
||||||
|
|
||||||
|
cd benchmarks/enron_emails
|
||||||
|
python setup_enron_emails.py --data-dir data
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||||
|
Alternatively, pass a local path to `--emails-csv`.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||||
|
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||||
|
|
||||||
|
2) Run recall evaluation (Stage 2)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||||
|
|
||||||
|
3) Complexity sweep (Stage 3)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||||
|
|
||||||
|
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||||
|
|
||||||
|
4) Index comparison (Stage 4)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||||
|
|
||||||
|
5) Generation evaluation (Stage 5)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||||
|
|
||||||
|
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||||
|
|
||||||
|
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||||
|
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||||
|
- `--baseline-dir` defaults to `baseline`
|
||||||
|
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||||
|
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||||
|
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||||
|
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||||
|
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||||
|
- --baseline-dir baseline (where FAISS baseline lives)
|
||||||
|
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||||
|
- --llm-backend hf|vllm (LLM backend for generation)
|
||||||
|
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||||
|
- --max-queries 1000 (limit number of queries for evaluation)
|
||||||
|
|
||||||
|
## Files Produced
|
||||||
|
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||||
|
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||||
|
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||||
|
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||||
|
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||||
|
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||||
|
|
||||||
|
## Stages Summary
|
||||||
|
|
||||||
|
- Stage 2 (Recall@3):
|
||||||
|
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||||
|
- Compact index runs with `recompute_embeddings=True`.
|
||||||
|
|
||||||
|
- Stage 3 (Binary Search for Complexity):
|
||||||
|
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||||
|
|
||||||
|
- Stage 4 (Index Comparison):
|
||||||
|
- Reports .index-only sizes for compact vs non-compact.
|
||||||
|
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||||
|
- Stores retrieval results for Stage 5 generation evaluation.
|
||||||
|
- Fails fast if compact recompute cannot run.
|
||||||
|
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||||
|
- First from the current run (when running `--stage all`), otherwise
|
||||||
|
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||||
|
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||||
|
|
||||||
|
- Stage 5 (Generation Evaluation):
|
||||||
|
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||||
|
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||||
|
- Measures generation timing separately from search timing.
|
||||||
|
- Requires Stage 4 results (no additional searching performed).
|
||||||
|
|
||||||
|
## Example Results
|
||||||
|
|
||||||
|
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||||
|
|
||||||
|
- Stage 3 (Binary Search):
|
||||||
|
- Minimal complexity achieving 90% Recall@3: 88
|
||||||
|
- Sampled points:
|
||||||
|
- C=8 → 59.9% Recall@3
|
||||||
|
- C=72 → 89.4% Recall@3
|
||||||
|
- C=88 → 90.2% Recall@3
|
||||||
|
- C=96 → 90.7% Recall@3
|
||||||
|
- C=112 → 91.1% Recall@3
|
||||||
|
- C=136 → 91.3% Recall@3
|
||||||
|
- C=256 → 92.0% Recall@3
|
||||||
|
|
||||||
|
- Stage 4 (Index Sizes, .index only):
|
||||||
|
- Compact: ~2.2 MB
|
||||||
|
- Non-compact: ~82.0 MB
|
||||||
|
- Storage saving by compact: ~97.3%
|
||||||
|
|
||||||
|
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||||
|
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||||
|
- Compact (with recompute): ~1.981 s avg per query
|
||||||
|
- Speed ratio (non-compact/compact): ~0.0038x
|
||||||
|
|
||||||
|
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||||
|
- Average generation time: ~22.302 s per query
|
||||||
|
- Total queries processed: 988
|
||||||
|
- LLM backend: HuggingFace transformers
|
||||||
|
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||||
|
|
||||||
|
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||||
|
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||||
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
downloads/
|
||||||
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||||
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||||
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||||
|
On errors, fail fast without fallbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Compute query embedding with the same model/mode as the index
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS Flat ground truth
|
||||||
|
n = q_emb.shape[0]
|
||||||
|
k = 3
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(q_emb),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN (may require embedding server depending on index configuration)
|
||||||
|
results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {r.id for r in results}
|
||||||
|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN: {list(test_ids)}")
|
||||||
|
print(f" ∩: {list(intersection)}")
|
||||||
|
|
||||||
|
avg = total_recall / max(1, len(queries))
|
||||||
|
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||||
|
return avg
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class EnronEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
if "query" in data:
|
||||||
|
queries.append(data["query"])
|
||||||
|
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||||
|
|
||||||
|
sizes["index_only_mb"] = (
|
||||||
|
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source and embedding model
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
# Load all passages and ids
|
||||||
|
ids: list[str] = []
|
||||||
|
texts: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
texts.append(data["text"])
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
meta["embedding_model"],
|
||||||
|
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False,
|
||||||
|
is_compact=False,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare search speed for non-compact vs compact indexes."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
results: dict = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||||
|
}
|
||||||
|
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
# Non-compact (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Compact (with recompute). Fail fast if it cannot run.
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
docs = compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
results["compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Store retrieval results for Stage 5
|
||||||
|
results["retrieval_results"].append(
|
||||||
|
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||||
|
)
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
if results["non_compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
if results["compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
if results["avg_search_times"].get("compact", 0) > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = 0.0
|
||||||
|
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_complexity(
|
||||||
|
self,
|
||||||
|
recall_eval: "RecallEvaluator",
|
||||||
|
queries: list[str],
|
||||||
|
target: float = 0.90,
|
||||||
|
c_min: int = 8,
|
||||||
|
c_max: int = 256,
|
||||||
|
max_iters: int = 10,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||||
|
|
||||||
|
def round_c(x: int) -> int:
|
||||||
|
# snap to multiple of 8 like other benches typically do
|
||||||
|
return max(1, int((x + 7) // 8) * 8)
|
||||||
|
|
||||||
|
metrics: list[dict] = []
|
||||||
|
|
||||||
|
lo = round_c(c_min)
|
||||||
|
hi = round_c(c_max)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||||
|
r_lo = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=lo, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
cap = 1024
|
||||||
|
while r_hi < target and hi < cap:
|
||||||
|
lo = hi
|
||||||
|
r_lo = r_hi
|
||||||
|
hi = round_c(hi * 2)
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
if r_hi < target:
|
||||||
|
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||||
|
print("📈 Observations:")
|
||||||
|
for m in metrics:
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||||
|
|
||||||
|
# Binary search within [lo, hi]
|
||||||
|
best = hi
|
||||||
|
iters = 0
|
||||||
|
while lo < hi and iters < max_iters:
|
||||||
|
mid = round_c((lo + hi) // 2)
|
||||||
|
r_mid = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=mid, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||||
|
if r_mid >= target:
|
||||||
|
best = mid
|
||||||
|
hi = mid
|
||||||
|
else:
|
||||||
|
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||||
|
iters += 1
|
||||||
|
|
||||||
|
print("📈 Binary search results (sampled points):")
|
||||||
|
# Print unique complexity entries ordered by complexity
|
||||||
|
for m in sorted(
|
||||||
|
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||||
|
):
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||||
|
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all", "45"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||||
|
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Resolve queries file: if default path not found, fall back to index's directory
|
||||||
|
if not os.path.exists(args.queries):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
idx_dir = Path(args.index).parent
|
||||||
|
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||||
|
if fallback_q.exists():
|
||||||
|
args.queries = str(fallback_q)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
results_out: dict = {}
|
||||||
|
|
||||||
|
if args.stage in ("2", "all"):
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[:10]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||||
|
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("3", "all"):
|
||||||
|
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[: args.max_queries]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact evaluator for binary search with recompute=False
|
||||||
|
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
sweep = enron_eval.evaluate_complexity(
|
||||||
|
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||||
|
)
|
||||||
|
results_out["stage3"] = sweep
|
||||||
|
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"stage3": sweep}, f, indent=2)
|
||||||
|
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||||
|
evaluator_nc.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 3 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("4", "all", "45"):
|
||||||
|
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
test_q = queries[: min(args.max_queries, len(queries))]
|
||||||
|
|
||||||
|
current_sizes = enron_eval.analyze_index_sizes()
|
||||||
|
# Build non-compact index for comparison (no fallback)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||||
|
nc_eval = EnronEvaluator(non_compact_path)
|
||||||
|
|
||||||
|
if (
|
||||||
|
current_sizes.get("index_only_mb", 0) > 0
|
||||||
|
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||||
|
):
|
||||||
|
storage_saving_percent = max(
|
||||||
|
0.0,
|
||||||
|
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
storage_saving_percent = 0.0
|
||||||
|
|
||||||
|
if args.complexity is None:
|
||||||
|
# Prefer in-session Stage 3 result
|
||||||
|
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||||
|
complexity = results_out["stage3"]["best_complexity"]
|
||||||
|
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||||
|
else:
|
||||||
|
# Try to load last saved Stage 3 result near index
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
if default_stage3_path.exists():
|
||||||
|
with open(default_stage3_path, encoding="utf-8") as f:
|
||||||
|
prev = json.load(f)
|
||||||
|
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||||
|
if complexity is None:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||||
|
)
|
||||||
|
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||||
|
else:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
comp = enron_eval.compare_index_performance(
|
||||||
|
non_compact_path, args.index, test_q, complexity=complexity
|
||||||
|
)
|
||||||
|
results_out["stage4"] = {
|
||||||
|
"current_index": current_sizes,
|
||||||
|
"non_compact_index": non_compact_sizes,
|
||||||
|
"storage_saving_percent": storage_saving_percent,
|
||||||
|
"performance_comparison": comp,
|
||||||
|
}
|
||||||
|
nc_eval.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||||
|
|
||||||
|
# Check if Stage 4 results exist
|
||||||
|
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||||
|
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||||
|
print("💡 Run Stage 4 first or use --stage all")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||||
|
if not retrieval_results:
|
||||||
|
print("❌ No retrieval results found from Stage 4")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||||
|
|
||||||
|
# Load LLM
|
||||||
|
try:
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Run generation using stored retrieval results
|
||||||
|
import time
|
||||||
|
|
||||||
|
from llm_utils import create_prompt
|
||||||
|
|
||||||
|
generation_times = []
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
print("🤖 Running generation on pre-retrieved results...")
|
||||||
|
for i, item in enumerate(retrieval_results):
|
||||||
|
query = item["query"]
|
||||||
|
retrieved_docs = item["retrieved_docs"]
|
||||||
|
|
||||||
|
# Prepare context from retrieved docs
|
||||||
|
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||||
|
prompt = create_prompt(context, query, "emails")
|
||||||
|
|
||||||
|
# Time generation only
|
||||||
|
gen_start = time.time()
|
||||||
|
response = llm_func(prompt)
|
||||||
|
gen_time = time.time() - gen_start
|
||||||
|
|
||||||
|
generation_times.append(gen_time)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||||
|
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Queries: {len(retrieval_results)}")
|
||||||
|
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||||
|
print(" (Search time from Stage 4)")
|
||||||
|
|
||||||
|
results_out["stage5"] = {
|
||||||
|
"total_queries": len(retrieval_results),
|
||||||
|
"avg_generation_time": avg_gen_time,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"responses": responses,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Results:")
|
||||||
|
for i in range(min(3, len(retrieval_results))):
|
||||||
|
query = retrieval_results[i]["query"]
|
||||||
|
response = responses[i]
|
||||||
|
print(f" Q{i + 1}: {query[:60]}...")
|
||||||
|
print(f" A{i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||||
|
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.output and results_out:
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results_out, f, indent=2)
|
||||||
|
print(f"📝 Saved results to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Setup Script
|
||||||
|
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from email import message_from_string
|
||||||
|
from email.policy import default
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class EnronSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||||
|
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||||
|
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
self.downloads_dir = self.data_dir / "downloads"
|
||||||
|
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Dataset acquisition
|
||||||
|
# ----------------------------
|
||||||
|
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||||
|
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||||
|
if emails_csv:
|
||||||
|
p = Path(emails_csv)
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
|
api.dataset_download_files(
|
||||||
|
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||||
|
)
|
||||||
|
candidate = self.downloads_dir / "emails.csv"
|
||||||
|
if candidate.exists():
|
||||||
|
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||||
|
return str(candidate)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data preparation
|
||||||
|
# ----------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_id(raw_email: str) -> str:
|
||||||
|
msg = message_from_string(raw_email, policy=default)
|
||||||
|
val = msg.get("Message-ID", "")
|
||||||
|
if val.startswith("<") and val.endswith(">"):
|
||||||
|
val = val[1:-1]
|
||||||
|
return val or ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||||
|
parts = raw_email.split("\n\n", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0].strip(), parts[1].strip()
|
||||||
|
# Heuristic fallback
|
||||||
|
first_lines = raw_email.splitlines()
|
||||||
|
if first_lines and ":" in first_lines[0]:
|
||||||
|
return raw_email.strip(), ""
|
||||||
|
return "", raw_email.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||||
|
text = (text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
if chunk_words <= 0:
|
||||||
|
return [text]
|
||||||
|
words = text.split()
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
limit = len(words)
|
||||||
|
if not keep_last:
|
||||||
|
limit = (len(words) // chunk_words) * chunk_words
|
||||||
|
if limit == 0:
|
||||||
|
return []
|
||||||
|
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||||
|
return [c for c in (s.strip() for s in chunks) if c]
|
||||||
|
|
||||||
|
def _iter_passages_from_csv(
|
||||||
|
self,
|
||||||
|
emails_csv: Path,
|
||||||
|
chunk_words: int = 256,
|
||||||
|
keep_last_header: bool = True,
|
||||||
|
keep_last_body: bool = True,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> Iterable[dict]:
|
||||||
|
with open(emails_csv, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
count = 0
|
||||||
|
for i, row in enumerate(reader):
|
||||||
|
if max_emails is not None and count >= max_emails:
|
||||||
|
break
|
||||||
|
|
||||||
|
raw_message = row.get("message", "")
|
||||||
|
email_file_id = row.get("file", "")
|
||||||
|
|
||||||
|
if not raw_message.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
message_id = self._extract_message_id(raw_message)
|
||||||
|
if not message_id:
|
||||||
|
# Fallback ID based on CSV position and file path
|
||||||
|
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||||
|
message_id = f"enron_{i}_{safe_file}"
|
||||||
|
|
||||||
|
header, body = self._split_header_body(raw_message)
|
||||||
|
|
||||||
|
# Header chunks
|
||||||
|
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": True,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Body chunks
|
||||||
|
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": False,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Build LEANN index and FAISS baseline
|
||||||
|
# ----------------------------
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
emails_csv: Optional[str],
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
chunk_words: int = 256,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||||
|
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream passages and add to builder
|
||||||
|
preview_written = 0
|
||||||
|
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||||
|
for p in self._iter_passages_from_csv(
|
||||||
|
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||||
|
):
|
||||||
|
builder.add_text(p["text"], metadata=p["metadata"])
|
||||||
|
if preview_written < 200:
|
||||||
|
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||||
|
preview_written += 1
|
||||||
|
|
||||||
|
print(f"🔨 Building index at {self.index_path}...")
|
||||||
|
builder.build_index(str(self.index_path))
|
||||||
|
print("✅ LEANN index built!")
|
||||||
|
return str(self.index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||||
|
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read meta for passage source and embedding model
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
# Load passages from builder output so IDs match LEANN
|
||||||
|
passages: list[str] = []
|
||||||
|
passage_ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"]) # builder-assigned ID
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||||
|
print(f"🤖 Embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS IndexFlatIP
|
||||||
|
dim = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dim)
|
||||||
|
emb_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||||
|
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as pf:
|
||||||
|
pickle.dump(passage_ids, pf)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved: {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Queries (optional): prepare evaluation queries file
|
||||||
|
# ----------------------------
|
||||||
|
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||||
|
print(
|
||||||
|
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to load dataset: {e}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
kept = 0
|
||||||
|
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||||
|
for i, item in enumerate(ds):
|
||||||
|
how_realistic = float(item.get("how_realistic", 0.0))
|
||||||
|
if how_realistic < min_realism:
|
||||||
|
continue
|
||||||
|
qid = str(item.get("id", f"enron_q_{i}"))
|
||||||
|
query = item.get("question", "")
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
record = {
|
||||||
|
"id": qid,
|
||||||
|
"query": query,
|
||||||
|
# For reference only, not used in recall metric below
|
||||||
|
"gt_message_ids": item.get("message_ids", []),
|
||||||
|
}
|
||||||
|
out.write(json.dumps(record) + "\n")
|
||||||
|
kept += 1
|
||||||
|
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--emails-csv",
|
||||||
|
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model for LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||||
|
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||||
|
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup = EnronSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
emails_csv=args.emails_csv,
|
||||||
|
backend=args.backend,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
chunk_words=args.chunk_words,
|
||||||
|
max_emails=args.max_emails,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS baseline from the same passages & embeddings
|
||||||
|
setup.build_faiss_flat_baseline(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LEANN index build and baseline")
|
||||||
|
|
||||||
|
# Queries file (optional)
|
||||||
|
if not args.skip_queries:
|
||||||
|
setup.prepare_queries()
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping query preparation")
|
||||||
|
|
||||||
|
print("\n🎉 Enron Emails setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("Next steps:")
|
||||||
|
print(
|
||||||
|
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Test only Faiss HNSW"""
|
"""Test only Faiss HNSW"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import gc
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
def get_memory_usage():
|
||||||
@@ -37,20 +37,20 @@ def main():
|
|||||||
import faiss
|
import faiss
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("Faiss is not installed.")
|
print("Faiss is not installed.")
|
||||||
print("Please install it with `uv pip install faiss-cpu`")
|
print(
|
||||||
|
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||||
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
from llama_index.core import (
|
from llama_index.core import (
|
||||||
SimpleDirectoryReader,
|
|
||||||
VectorStoreIndex,
|
|
||||||
StorageContext,
|
|
||||||
Settings,
|
Settings,
|
||||||
node_parser,
|
SimpleDirectoryReader,
|
||||||
Document,
|
StorageContext,
|
||||||
|
VectorStoreIndex,
|
||||||
)
|
)
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||||
|
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||||
|
|
||||||
tracker = MemoryTracker("Faiss HNSW")
|
tracker = MemoryTracker("Faiss HNSW")
|
||||||
tracker.checkpoint("Initial")
|
tracker.checkpoint("Initial")
|
||||||
@@ -65,7 +65,7 @@ def main():
|
|||||||
tracker.checkpoint("After Faiss index creation")
|
tracker.checkpoint("After Faiss index creation")
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"examples/data",
|
"data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
@@ -90,8 +90,9 @@ def main():
|
|||||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||||
)
|
)
|
||||||
from llama_index.core import load_index_from_storage
|
from llama_index.core import load_index_from_storage
|
||||||
|
|
||||||
index = load_index_from_storage(storage_context=storage_context)
|
index = load_index_from_storage(storage_context=storage_context)
|
||||||
print(f"Index loaded from ./storage_faiss")
|
print("Index loaded from ./storage_faiss")
|
||||||
tracker.checkpoint("After loading existing index")
|
tracker.checkpoint("After loading existing index")
|
||||||
index_loaded = True
|
index_loaded = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -99,6 +100,7 @@ def main():
|
|||||||
print("Cleaning up corrupted index and building new one...")
|
print("Cleaning up corrupted index and building new one...")
|
||||||
# Clean up corrupted index
|
# Clean up corrupted index
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
if os.path.exists("./storage_faiss"):
|
if os.path.exists("./storage_faiss"):
|
||||||
shutil.rmtree("./storage_faiss")
|
shutil.rmtree("./storage_faiss")
|
||||||
|
|
||||||
@@ -109,9 +111,7 @@ def main():
|
|||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||||
index = VectorStoreIndex.from_documents(
|
index = VectorStoreIndex.from_documents(
|
||||||
documents,
|
documents, storage_context=storage_context, transformations=[node_parser]
|
||||||
storage_context=storage_context,
|
|
||||||
transformations=[node_parser]
|
|
||||||
)
|
)
|
||||||
tracker.checkpoint("After index building")
|
tracker.checkpoint("After index building")
|
||||||
|
|
||||||
@@ -127,7 +127,7 @@ def main():
|
|||||||
|
|
||||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
115
benchmarks/financebench/README.md
Normal file
115
benchmarks/financebench/README.md
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
# FinanceBench Benchmark for LEANN-RAG
|
||||||
|
|
||||||
|
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||||
|
|
||||||
|
## Dataset
|
||||||
|
|
||||||
|
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||||
|
- **Questions**: 150 financial Q&A examples
|
||||||
|
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||||
|
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||||
|
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||||
|
|
||||||
|
## Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
benchmarks/financebench/
|
||||||
|
├── setup_financebench.py # Downloads PDFs and builds index
|
||||||
|
├── evaluate_financebench.py # Intelligent evaluation script
|
||||||
|
├── data/
|
||||||
|
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||||
|
│ ├── pdfs/ # Downloaded financial documents
|
||||||
|
│ └── index/ # LEANN indexes
|
||||||
|
│ └── financebench_full_hnsw.leann
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Setup (Download & Build Index)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/financebench
|
||||||
|
python setup_financebench.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Download the 150 Q&A examples
|
||||||
|
- Download all 368 PDF documents (parallel processing)
|
||||||
|
- Build a LEANN index from 53K+ text chunks
|
||||||
|
- Verify setup with test query
|
||||||
|
|
||||||
|
### 2. Evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic retrieval evaluation
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||||
|
|
||||||
|
|
||||||
|
# RAG generation evaluation with Qwen3-8B
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Methods
|
||||||
|
|
||||||
|
### Retrieval Evaluation
|
||||||
|
Uses intelligent matching with three strategies:
|
||||||
|
1. **Exact text overlap** - Direct substring matches
|
||||||
|
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||||
|
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||||
|
|
||||||
|
### QA Evaluation
|
||||||
|
LLM-based answer evaluation using GPT-4o:
|
||||||
|
- Handles numerical rounding and equivalent representations
|
||||||
|
- Considers fractions, percentages, and decimal equivalents
|
||||||
|
- Evaluates semantic meaning rather than exact text match
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||||
|
|
||||||
|
**Retrieval Metrics:**
|
||||||
|
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||||
|
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||||
|
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||||
|
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||||
|
- **Average Search Time**: 0.097s
|
||||||
|
|
||||||
|
**QA Metrics:**
|
||||||
|
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||||
|
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||||
|
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||||
|
|
||||||
|
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||||
|
|
||||||
|
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||||
|
|
||||||
|
- **Stage 4 (Index Comparison):**
|
||||||
|
- Compact Index: 5.0 MB
|
||||||
|
- Non-compact Index: 172.2 MB
|
||||||
|
- **Storage Saving**: 97.1%
|
||||||
|
- **Search Performance**:
|
||||||
|
- Non-compact (no recompute): 0.009s avg per query
|
||||||
|
- Compact (with recompute): 2.203s avg per query
|
||||||
|
- Speed ratio: 0.004x
|
||||||
|
|
||||||
|
**Generation Evaluation (20 queries, complexity=64):**
|
||||||
|
- **Average Search Time**: 1.638s per query
|
||||||
|
- **Average Generation Time**: 45.957s per query
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||||
|
- **Total Questions Processed**: 20
|
||||||
|
|
||||||
|
## Options
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use different backends
|
||||||
|
python setup_financebench.py --backend diskann
|
||||||
|
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||||
|
|
||||||
|
# Use different embedding models
|
||||||
|
python setup_financebench.py --embedding-model facebook/contriever
|
||||||
|
```
|
||||||
923
benchmarks/financebench/evaluate_financebench.py
Executable file
923
benchmarks/financebench/evaluate_financebench.py
Executable file
@@ -0,0 +1,923 @@
|
|||||||
|
"""
|
||||||
|
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
from leann import LeannChat, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(queries)
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Get ground truth: search with FAISS flat
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
query_embedding = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchEvaluator:
|
||||||
|
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||||
|
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||||
|
|
||||||
|
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||||
|
"""Load FinanceBench dataset"""
|
||||||
|
data = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||||
|
return data
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes with and without embeddings"""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes = {}
|
||||||
|
total_with_embeddings = 0
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
for file_path, name in [
|
||||||
|
(index_file, "index"),
|
||||||
|
(meta_file, "metadata"),
|
||||||
|
(passages_file, "passages_text"),
|
||||||
|
(passages_idx_file, "passages_index"),
|
||||||
|
]:
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
sizes[name] = size_mb
|
||||||
|
total_with_embeddings += size_mb
|
||||||
|
|
||||||
|
else:
|
||||||
|
sizes[name] = 0
|
||||||
|
|
||||||
|
sizes["total_with_embeddings"] = total_with_embeddings
|
||||||
|
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||||
|
|
||||||
|
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||||
|
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||||
|
"""Create a compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage
|
||||||
|
**meta.get("backend_kwargs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||||
|
builder.build_index(compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||||
|
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
compact_sizes["index_type"] = "compact"
|
||||||
|
|
||||||
|
return compact_sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||||
|
builder.build_index(non_compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = [item["question"] for item in test_data[:5]]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_timing_breakdown(
|
||||||
|
self, data: list[dict], max_samples: Optional[int] = None
|
||||||
|
) -> dict:
|
||||||
|
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||||
|
if not self.chat or not self.openai_client:
|
||||||
|
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||||
|
return {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0.0,
|
||||||
|
"avg_generation_time": 0.0,
|
||||||
|
"avg_total_time": 0.0,
|
||||||
|
"accuracy": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||||
|
|
||||||
|
if max_samples:
|
||||||
|
data = data[:max_samples]
|
||||||
|
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
generation_times = []
|
||||||
|
total_times = []
|
||||||
|
correct_answers = 0
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
question = item["question"]
|
||||||
|
ground_truth = item["answer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Hack: Monkey-patch the ask method to capture internal timing
|
||||||
|
original_ask = self.chat.ask
|
||||||
|
captured_search_time = None
|
||||||
|
captured_generation_time = None
|
||||||
|
|
||||||
|
def patched_ask(*args, **kwargs):
|
||||||
|
nonlocal captured_search_time, captured_generation_time
|
||||||
|
|
||||||
|
# Time the search part
|
||||||
|
search_start = time.time()
|
||||||
|
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||||
|
captured_search_time = time.time() - search_start
|
||||||
|
|
||||||
|
# Time the generation part
|
||||||
|
context = "\n\n".join([r.text for r in results])
|
||||||
|
prompt = (
|
||||||
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
f"Question: {args[0]}\n\n"
|
||||||
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_start = time.time()
|
||||||
|
answer = self.chat.llm.ask(prompt)
|
||||||
|
captured_generation_time = time.time() - generation_start
|
||||||
|
|
||||||
|
return answer
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
self.chat.ask = patched_ask
|
||||||
|
|
||||||
|
# Time the total QA
|
||||||
|
total_start = time.time()
|
||||||
|
generated_answer = self.chat.ask(question)
|
||||||
|
total_time = time.time() - total_start
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
self.chat.ask = original_ask
|
||||||
|
|
||||||
|
# Store the timings
|
||||||
|
search_times.append(captured_search_time)
|
||||||
|
generation_times.append(captured_generation_time)
|
||||||
|
total_times.append(total_time)
|
||||||
|
|
||||||
|
# Check accuracy using LLM as judge
|
||||||
|
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||||
|
if is_correct:
|
||||||
|
correct_answers += 1
|
||||||
|
|
||||||
|
status = "✅" if is_correct else "❌"
|
||||||
|
print(
|
||||||
|
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||||
|
)
|
||||||
|
print(f" GT: {ground_truth}")
|
||||||
|
print(f" Gen: {generated_answer[:100]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error: {e}")
|
||||||
|
search_times.append(0.0)
|
||||||
|
generation_times.append(0.0)
|
||||||
|
total_times.append(0.0)
|
||||||
|
|
||||||
|
accuracy = correct_answers / len(data) if data else 0.0
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"total_questions": len(data),
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||||
|
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||||
|
if generation_times
|
||||||
|
else 0.0,
|
||||||
|
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"correct_answers": correct_answers,
|
||||||
|
"search_times": search_times,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"total_times": total_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _check_answer_accuracy(
|
||||||
|
self, generated_answer: str, ground_truth: str, question: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||||
|
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Ground Truth Answer: {ground_truth}
|
||||||
|
|
||||||
|
Generated Answer: {generated_answer}
|
||||||
|
|
||||||
|
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||||
|
1. Numerical accuracy (exact values, units, currency)
|
||||||
|
2. Key financial concepts and terminology
|
||||||
|
3. Overall factual correctness
|
||||||
|
|
||||||
|
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||||
|
|
||||||
|
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
judge_response = self.openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||||
|
return judgment == "CORRECT"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||||
|
# Fallback to simple string matching
|
||||||
|
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
return gt_clean in gen_clean
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 EVALUATION RESULTS")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Index comparison analysis
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
|
||||||
|
print("\n📊 Accuracy:")
|
||||||
|
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||||
|
print(
|
||||||
|
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Timing Breakdown:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||||
|
|
||||||
|
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||||
|
search_pct = (
|
||||||
|
timing_metrics.get("avg_search_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
gen_pct = (
|
||||||
|
timing_metrics.get("avg_generation_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print("\n📈 Time Distribution:")
|
||||||
|
print(f" Search: {search_pct:.1f}%")
|
||||||
|
print(f" Generation: {gen_pct:.1f}%")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load FinanceBench queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Test with more queries for robust measurement
|
||||||
|
test_queries = queries[:2000]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = FinanceBenchEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Use more queries for robust measurement
|
||||||
|
test_queries = queries[:200]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity (without recompute for speed)
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 32
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||||
|
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||||
|
|
||||||
|
# Use FinanceBench evaluator for QA evaluation
|
||||||
|
evaluator = FinanceBenchEvaluator(
|
||||||
|
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
data = evaluator.load_dataset(args.dataset)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes
|
||||||
|
print("\n📊 Index size comparison:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print("\n📊 Index-only size comparison (.index file only):")
|
||||||
|
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
# Use index-only size for fair comparison (same as Enron emails)
|
||||||
|
storage_saving = (
|
||||||
|
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||||
|
/ non_compact_size_metrics["index_only_mb"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for performance comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Generation evaluation
|
||||||
|
test_samples = 20
|
||||||
|
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||||
|
|
||||||
|
if args.llm_backend == "openai" and args.openai_api_key:
|
||||||
|
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||||
|
evaluation_start = time.time()
|
||||||
|
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||||
|
evaluation_time = time.time() - evaluation_start
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# Load LLM
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Simple generation evaluation
|
||||||
|
queries = [item["question"] for item in data[:test_samples]]
|
||||||
|
gen_results = evaluate_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
llm_func,
|
||||||
|
queries,
|
||||||
|
domain="finance",
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": len(queries),
|
||||||
|
"avg_search_time": gen_results["avg_search_time"],
|
||||||
|
"avg_generation_time": gen_results["avg_generation_time"],
|
||||||
|
"results": gen_results["results"],
|
||||||
|
}
|
||||||
|
evaluation_time = time.time()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0,
|
||||||
|
"avg_generation_time": 0,
|
||||||
|
}
|
||||||
|
evaluation_time = 0
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
**timing_metrics,
|
||||||
|
"total_evaluation_time": evaluation_time,
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||||
|
print(" - Run full evaluation on complete dataset if needed")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
462
benchmarks/financebench/setup_financebench.py
Executable file
462
benchmarks/financebench/setup_financebench.py
Executable file
@@ -0,0 +1,462 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FinanceBench Complete Setup Script
|
||||||
|
Downloads all PDFs and builds full LEANN datastore
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import pymupdf
|
||||||
|
import requests
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||||
|
self.data_dir = self.base_dir / data_dir
|
||||||
|
self.pdf_dir = self.data_dir / "pdfs"
|
||||||
|
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||||
|
self.index_dir = self.data_dir / "index"
|
||||||
|
self.download_lock = Lock()
|
||||||
|
|
||||||
|
def download_dataset(self):
|
||||||
|
"""Download the main FinanceBench dataset"""
|
||||||
|
print("📊 Downloading FinanceBench dataset...")
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if self.dataset_file.exists():
|
||||||
|
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(self.dataset_file, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||||
|
|
||||||
|
def get_pdf_list(self):
|
||||||
|
"""Get list of all PDF files from GitHub"""
|
||||||
|
print("📋 Fetching PDF list from GitHub...")
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
pdf_files = response.json()
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files")
|
||||||
|
return pdf_files
|
||||||
|
|
||||||
|
def download_single_pdf(self, pdf_info, position):
|
||||||
|
"""Download a single PDF file"""
|
||||||
|
pdf_name = pdf_info["name"]
|
||||||
|
pdf_path = self.pdf_dir / pdf_name
|
||||||
|
|
||||||
|
# Skip if already downloaded
|
||||||
|
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||||
|
return f"✅ {pdf_name} (cached)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download PDF
|
||||||
|
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Write to file
|
||||||
|
with self.download_lock:
|
||||||
|
with open(pdf_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ {pdf_name}: {e!s}"
|
||||||
|
|
||||||
|
def download_all_pdfs(self, max_workers: int = 5):
|
||||||
|
"""Download all PDF files with parallel processing"""
|
||||||
|
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pdf_files = self.get_pdf_list()
|
||||||
|
|
||||||
|
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all download tasks
|
||||||
|
future_to_pdf = {
|
||||||
|
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||||
|
for i, pdf_info in enumerate(pdf_files)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process completed downloads with progress bar
|
||||||
|
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||||
|
for future in as_completed(future_to_pdf):
|
||||||
|
result = future.result()
|
||||||
|
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Verify downloads
|
||||||
|
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||||
|
|
||||||
|
# Show any failures
|
||||||
|
missing_pdfs = []
|
||||||
|
for pdf_info in pdf_files:
|
||||||
|
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||||
|
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||||
|
missing_pdfs.append(pdf_info["name"])
|
||||||
|
|
||||||
|
if missing_pdfs:
|
||||||
|
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||||
|
for pdf in missing_pdfs[:5]: # Show first 5
|
||||||
|
print(f" - {pdf}")
|
||||||
|
if len(missing_pdfs) > 5:
|
||||||
|
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||||
|
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
):
|
||||||
|
"""Build LEANN index from all PDFs"""
|
||||||
|
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
# Check if we have PDFs
|
||||||
|
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
if not pdf_files:
|
||||||
|
raise RuntimeError("No PDF files found! Run download first.")
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files to process")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Initialize builder with standard compact configuration
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage (pruned)
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process PDFs and extract text
|
||||||
|
total_chunks = 0
|
||||||
|
failed_pdfs = []
|
||||||
|
|
||||||
|
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||||
|
try:
|
||||||
|
chunks = self.extract_pdf_text(pdf_path)
|
||||||
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
|
total_chunks += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||||
|
failed_pdfs.append(pdf_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build index in index directory
|
||||||
|
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||||
|
print(f"🔨 Building index: {index_path}")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
print("✅ Index built successfully!")
|
||||||
|
print(f" 📁 Index path: {index_path}")
|
||||||
|
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||||
|
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||||
|
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||||
|
|
||||||
|
if failed_pdfs:
|
||||||
|
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||||
|
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||||
|
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read metadata from the built index
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.loads(f.read())
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
print(f"📊 Loading passages from {passage_file}...")
|
||||||
|
print(f"🤖 Using embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
# Load all passages for baseline
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages")
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
print("🧮 Computing embeddings...")
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(passage_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||||
|
"""Extract and chunk text from a PDF file"""
|
||||||
|
chunks = []
|
||||||
|
doc = pymupdf.open(pdf_path)
|
||||||
|
|
||||||
|
for page_num in range(len(doc)):
|
||||||
|
page = doc[page_num]
|
||||||
|
text = page.get_text() # type: ignore
|
||||||
|
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"source_file": pdf_path.name,
|
||||||
|
"page_number": page_num + 1,
|
||||||
|
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||||
|
"company": pdf_path.name.split("_")[0],
|
||||||
|
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use recursive character splitting like LangChain
|
||||||
|
if len(text.split()) > 500:
|
||||||
|
# Split by double newlines (paragraphs)
|
||||||
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
|
|
||||||
|
current_chunk = ""
|
||||||
|
for para in paragraphs:
|
||||||
|
# If adding this paragraph would make chunk too long, save current chunk
|
||||||
|
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
current_chunk = para
|
||||||
|
else:
|
||||||
|
current_chunk = (current_chunk + " " + para).strip()
|
||||||
|
|
||||||
|
# Add the last chunk
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Page is short enough, use as single chunk
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": text.strip(),
|
||||||
|
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_year_from_filename(self, filename: str) -> str:
|
||||||
|
"""Extract year from PDF filename"""
|
||||||
|
# Try to find 4-digit year in filename
|
||||||
|
|
||||||
|
match = re.search(r"(\d{4})", filename)
|
||||||
|
return match.group(1) if match else "unknown"
|
||||||
|
|
||||||
|
def verify_setup(self, index_path: str):
|
||||||
|
"""Verify the setup by testing a simple query"""
|
||||||
|
print("🧪 Verifying setup with test query...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test query
|
||||||
|
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||||
|
results = searcher.search(test_query, top_k=3)
|
||||||
|
|
||||||
|
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
company = result.metadata.get("company", "Unknown")
|
||||||
|
year = result.metadata.get("doc_period", "Unknown")
|
||||||
|
page = result.metadata.get("page_number", "Unknown")
|
||||||
|
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:100]}...")
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
print("✅ Setup verification completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Setup verification failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
parser.add_argument(
|
||||||
|
"--build-baseline-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only build FAISS baseline from existing index",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🏦 FinanceBench Complete Setup")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
setup = FinanceBenchSetup(args.data_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.build_baseline_only:
|
||||||
|
# Only build baseline from existing index
|
||||||
|
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||||
|
index_file = f"{index_path}.index"
|
||||||
|
meta_file = f"{index_path}.leann.meta.json"
|
||||||
|
|
||||||
|
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||||
|
print("❌ Index files not found:")
|
||||||
|
print(f" Index: {index_file}")
|
||||||
|
print(f" Meta: {meta_file}")
|
||||||
|
print("💡 Run without --build-baseline-only to build the index first")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Download dataset
|
||||||
|
setup.download_dataset()
|
||||||
|
|
||||||
|
# Step 2: Download PDFs
|
||||||
|
if not args.skip_download:
|
||||||
|
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping PDF download")
|
||||||
|
|
||||||
|
# Step 3: Build LEANN index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
backend=args.backend, embedding_model=args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Building FAISS flat baseline...")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
|
||||||
|
# Step 5: Verify setup
|
||||||
|
setup.verify_setup(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
|
||||||
|
print("\n🎉 FinanceBench setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print(
|
||||||
|
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
214
benchmarks/financebench/verify_recall.py
Normal file
214
benchmarks/financebench/verify_recall.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.9"
|
||||||
|
# dependencies = [
|
||||||
|
# "faiss-cpu",
|
||||||
|
# "numpy",
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
"""
|
||||||
|
Independent recall verification script using standard FAISS.
|
||||||
|
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Direct embedding computation using sentence-transformers.
|
||||||
|
Copied logic to avoid dependency issues.
|
||||||
|
"""
|
||||||
|
print(f"Loading model: {model_name}")
|
||||||
|
model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
|
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||||
|
embeddings = model.encode(
|
||||||
|
chunks,
|
||||||
|
show_progress_bar=True,
|
||||||
|
batch_size=32,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||||
|
"""Load FinanceBench queries from dataset"""
|
||||||
|
queries = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
if len(queries) >= max_queries:
|
||||||
|
break
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||||
|
"""Load passages from LEANN index structure"""
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
passage_file = index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"Loading passages from {passage_file}")
|
||||||
|
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Loading passages"):
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages)} passages")
|
||||||
|
return passages, passage_ids
|
||||||
|
|
||||||
|
|
||||||
|
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||||
|
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
|
||||||
|
# Build Flat index (ground truth)
|
||||||
|
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||||
|
flat_index = faiss.IndexFlatIP(dimension)
|
||||||
|
flat_index.add(embeddings)
|
||||||
|
|
||||||
|
# Build HNSW index
|
||||||
|
print("Building FAISS IndexHNSWFlat...")
|
||||||
|
M = 32 # Same as LEANN default
|
||||||
|
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||||
|
hnsw_index.add(embeddings)
|
||||||
|
|
||||||
|
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||||
|
return flat_index, hnsw_index
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_recall_at_k(
|
||||||
|
query_embeddings: np.ndarray,
|
||||||
|
flat_index: faiss.Index,
|
||||||
|
hnsw_index: faiss.Index,
|
||||||
|
passage_ids: list[str],
|
||||||
|
k: int = 3,
|
||||||
|
ef_search: int = 64,
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||||
|
|
||||||
|
# Set search parameters for HNSW
|
||||||
|
hnsw_index.hnsw.efSearch = ef_search
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = query_embeddings.shape[0]
|
||||||
|
|
||||||
|
for i in range(num_queries):
|
||||||
|
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||||
|
|
||||||
|
# Get ground truth from Flat index (standard FAISS API)
|
||||||
|
flat_distances, flat_indices = flat_index.search(query, k)
|
||||||
|
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||||
|
|
||||||
|
# Get results from HNSW index (standard FAISS API)
|
||||||
|
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||||
|
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||||
|
|
||||||
|
# Calculate recall
|
||||||
|
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||||
|
recall = len(intersection) / k
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||||
|
print(f" Flat: {list(ground_truth_ids)}")
|
||||||
|
print(f" HNSW: {list(hnsw_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configuration
|
||||||
|
dataset_path = "data/financebench_merged.jsonl"
|
||||||
|
index_path = "data/index/financebench_full_hnsw.leann"
|
||||||
|
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
|
||||||
|
print("🔍 FAISS Recall Verification")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check if files exist
|
||||||
|
if not Path(dataset_path).exists():
|
||||||
|
print(f"❌ Dataset not found: {dataset_path}")
|
||||||
|
return
|
||||||
|
if not Path(f"{index_path}.meta.json").exists():
|
||||||
|
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
print("📖 Loading FinanceBench queries...")
|
||||||
|
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||||
|
print(f"Loaded {len(queries)} queries")
|
||||||
|
|
||||||
|
print("📄 Loading passages from LEANN index...")
|
||||||
|
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
print("🧮 Computing passage embeddings...")
|
||||||
|
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||||
|
|
||||||
|
print("🧮 Computing query embeddings...")
|
||||||
|
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||||
|
|
||||||
|
# Build FAISS indexes
|
||||||
|
print("🏗️ Building FAISS indexes...")
|
||||||
|
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||||
|
|
||||||
|
# Test different efSearch values (equivalent to LEANN complexity)
|
||||||
|
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||||
|
ef_search_values = [16, 32, 64, 128, 256]
|
||||||
|
|
||||||
|
for ef_search in ef_search_values:
|
||||||
|
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
recall = evaluate_recall_at_k(
|
||||||
|
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Verification completed!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||||
|
print(" - Compared recall@3 at different efSearch values")
|
||||||
|
print(" - Used same embedding model as LEANN")
|
||||||
|
print(" - This validates LEANN's recall measurements")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1
benchmarks/laion/.gitignore
vendored
Normal file
1
benchmarks/laion/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
data/
|
||||||
199
benchmarks/laion/README.md
Normal file
199
benchmarks/laion/README.md
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
# LAION Multimodal Benchmark
|
||||||
|
|
||||||
|
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This benchmark evaluates:
|
||||||
|
- **Image retrieval timing** using caption-based queries
|
||||||
|
- **Recall@K performance** for image search
|
||||||
|
- **Complexity analysis** across different search parameters
|
||||||
|
- **Index size and storage efficiency**
|
||||||
|
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||||
|
|
||||||
|
## Dataset Configuration
|
||||||
|
|
||||||
|
- **Dataset**: LAION-400M subset (10,000 images)
|
||||||
|
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||||
|
- **Queries**: 200 random captions from the dataset
|
||||||
|
- **Ground Truth**: Self-recall (query caption → original image)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Setup the benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/laion
|
||||||
|
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Create dummy LAION data (10K samples)
|
||||||
|
- Generate CLIP embeddings (512-dim)
|
||||||
|
- Build LEANN index with HNSW backend
|
||||||
|
- Create 200 evaluation queries
|
||||||
|
|
||||||
|
### 2. Run evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all evaluation stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann
|
||||||
|
|
||||||
|
# Run specific stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||||
|
|
||||||
|
# Multimodal generation with Qwen2.5-VL
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Save results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Setup Options
|
||||||
|
```bash
|
||||||
|
python setup_laion.py \
|
||||||
|
--num-samples 10000 \
|
||||||
|
--num-queries 200 \
|
||||||
|
--index-path data/laion_index.leann \
|
||||||
|
--backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation Options
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py \
|
||||||
|
--index data/laion_index.leann \
|
||||||
|
--queries data/evaluation_queries.jsonl \
|
||||||
|
--complexity 64 \
|
||||||
|
--top-k 3 \
|
||||||
|
--num-samples 100 \
|
||||||
|
--stage all
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Stages
|
||||||
|
|
||||||
|
### Stage 2: Recall Evaluation
|
||||||
|
- Evaluates Recall@3 for multimodal retrieval
|
||||||
|
- Compares LEANN vs FAISS baseline performance
|
||||||
|
- Self-recall: query caption should retrieve original image
|
||||||
|
|
||||||
|
### Stage 3: Complexity Analysis
|
||||||
|
- Binary search for optimal complexity (90% recall target)
|
||||||
|
- Tests performance across different complexity levels
|
||||||
|
- Analyzes speed vs. accuracy tradeoffs
|
||||||
|
|
||||||
|
### Stage 4: Index Comparison
|
||||||
|
- Compares compact vs non-compact index sizes
|
||||||
|
- Measures search performance differences
|
||||||
|
- Reports storage efficiency and speed ratios
|
||||||
|
|
||||||
|
### Stage 5: Multimodal Generation
|
||||||
|
- Uses Qwen2.5-VL for image understanding and description
|
||||||
|
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||||
|
- Measures both search and generation timing
|
||||||
|
|
||||||
|
## Output Metrics
|
||||||
|
|
||||||
|
### Timing Metrics
|
||||||
|
- Average/median/min/max search time
|
||||||
|
- Standard deviation
|
||||||
|
- Searches per second
|
||||||
|
- Latency in milliseconds
|
||||||
|
|
||||||
|
### Recall Metrics
|
||||||
|
- Recall@3 percentage for image retrieval
|
||||||
|
- Number of queries with ground truth
|
||||||
|
|
||||||
|
### Index Metrics
|
||||||
|
- Total index size (MB)
|
||||||
|
- Component breakdown (index, passages, metadata)
|
||||||
|
- Storage savings (compact vs non-compact)
|
||||||
|
- Backend and embedding model info
|
||||||
|
|
||||||
|
### Generation Metrics (Stage 5)
|
||||||
|
- Average search time per query
|
||||||
|
- Average generation time per query
|
||||||
|
- Time distribution (search vs generation)
|
||||||
|
- Sample multimodal responses
|
||||||
|
- Model: Qwen2.5-VL performance
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||||
|
|
||||||
|
**Stage 3: Optimal Complexity Analysis**
|
||||||
|
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||||
|
- **Binary Search Range**: 1-128
|
||||||
|
- **Target Recall**: 90%
|
||||||
|
- **Index Type**: Non-compact (for fast binary search)
|
||||||
|
|
||||||
|
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||||
|
- **Total Queries**: 20
|
||||||
|
- **Average Search Time**: 1.200s per query
|
||||||
|
- **Average Generation Time**: 6.558s per query
|
||||||
|
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
- **Optimal Complexity**: 85
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||||
|
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||||
|
- **Backend**: HNSW with cosine distance
|
||||||
|
|
||||||
|
### Example Results
|
||||||
|
|
||||||
|
```
|
||||||
|
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
📊 Multimodal Generation Results:
|
||||||
|
Total Queries: 20
|
||||||
|
Avg Search Time: 1.200s
|
||||||
|
Avg Generation Time: 6.558s
|
||||||
|
Time Distribution: Search 15.5%, Generation 84.5%
|
||||||
|
LLM Backend: HuggingFace transformers
|
||||||
|
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
⚙️ Optimal Complexity Analysis:
|
||||||
|
Target Recall: 90%
|
||||||
|
Optimal Complexity: 85
|
||||||
|
Binary Search Range: 1-128
|
||||||
|
Non-compact Index (fast search, no recompute)
|
||||||
|
|
||||||
|
🚀 Performance Summary:
|
||||||
|
Multimodal RAG: 7.758s total per query
|
||||||
|
Search: 15.5% of total time
|
||||||
|
Generation: 84.5% of total time
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
benchmarks/laion/
|
||||||
|
├── setup_laion.py # Setup script
|
||||||
|
├── evaluate_laion.py # Evaluation script
|
||||||
|
├── README.md # This file
|
||||||
|
└── data/ # Generated data
|
||||||
|
├── laion_images/ # Image files (placeholder)
|
||||||
|
├── laion_metadata.jsonl # Image metadata
|
||||||
|
├── laion_passages.jsonl # LEANN passages
|
||||||
|
├── laion_embeddings.npy # CLIP embeddings
|
||||||
|
├── evaluation_queries.jsonl # Evaluation queries
|
||||||
|
└── laion_index.leann/ # LEANN index files
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Current implementation uses dummy data for demonstration
|
||||||
|
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||||
|
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||||
|
- Adjust `num_samples` and `num_queries` based on available resources
|
||||||
|
- Consider using `--num-samples` during evaluation for faster testing
|
||||||
725
benchmarks/laion/evaluate_laion.py
Normal file
725
benchmarks/laion/evaluate_laion.py
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline (image embeddings)
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.image_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||||
|
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(captions)
|
||||||
|
|
||||||
|
for i, caption in enumerate(captions):
|
||||||
|
# Get ground truth: search with FAISS flat using caption text embedding
|
||||||
|
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||||
|
query_embedding = self.st_clip.encode(
|
||||||
|
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results (image IDs from FAISS)
|
||||||
|
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity (using caption as text query)
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
caption,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
"""Load caption queries from evaluation file"""
|
||||||
|
captions = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
query_data = json.loads(line)
|
||||||
|
captions.append(query_data["query"])
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(captions)} caption queries")
|
||||||
|
return captions
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
# Core index size (.index only)
|
||||||
|
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
sizes["index_only_mb"] = index_mb
|
||||||
|
|
||||||
|
# Other files for reference (not counted in index_only_mb)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||||
|
if sizes["metadata_mb"]:
|
||||||
|
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||||
|
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||||
|
print(
|
||||||
|
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Load CLIP embeddings
|
||||||
|
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
distance_metric="cosine",
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare ids and add passages
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
# Ensure metadata contains the id used by the vector index
|
||||||
|
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||||
|
builder.add_text(text=data["text"], metadata=metadata)
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = test_captions[:5]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Index comparison analysis (prefer .index-only view if present)
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||||
|
print("\n📏 Index Comparison Analysis (.index only):")
|
||||||
|
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
# Show excluded components for reference if available
|
||||||
|
if any(
|
||||||
|
k in non_compact
|
||||||
|
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||||
|
):
|
||||||
|
print(" (passages excluded in totals, shown for reference):")
|
||||||
|
print(
|
||||||
|
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||||
|
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||||
|
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to legacy totals if running with older metrics
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(
|
||||||
|
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend",
|
||||||
|
choices=["hf"],
|
||||||
|
default="hf",
|
||||||
|
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_laion.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
laion_evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = laion_evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Test with queries for robust measurement
|
||||||
|
test_captions = captions[:100] # Use subset for speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Use subset for robust measurement
|
||||||
|
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 128
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Index comparison (without LLM generation)
|
||||||
|
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||||
|
|
||||||
|
# Use LAION evaluator for index comparison
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
|
||||||
|
# Load caption queries
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes (.index only)
|
||||||
|
print("\n📊 Index size comparison (.index only):")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||||
|
|
||||||
|
storage_saving = 0.0
|
||||||
|
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||||
|
storage_saving = (
|
||||||
|
(
|
||||||
|
non_compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
- compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
)
|
||||||
|
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for index comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print comprehensive results
|
||||||
|
evaluator._print_results(combined_metrics)
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||||
|
|
||||||
|
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||||
|
|
||||||
|
# Load Qwen2.5-VL model
|
||||||
|
try:
|
||||||
|
print("Loading Qwen2.5-VL model...")
|
||||||
|
processor, model = load_qwen_vl_model(args.model_name)
|
||||||
|
|
||||||
|
# Run multimodal generation evaluation
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
gen_results = evaluate_multimodal_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
test_captions,
|
||||||
|
processor=processor,
|
||||||
|
model=model,
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Multimodal Generation Results:")
|
||||||
|
print(f" Total Queries: {len(test_captions)}")
|
||||||
|
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||||
|
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||||
|
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||||
|
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||||
|
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||||
|
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||||
|
print(" LLM Backend: HuggingFace transformers")
|
||||||
|
print(f" Model: {args.model_name}")
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Multimodal Generations:")
|
||||||
|
for i, response in enumerate(gen_results["results"][:3]):
|
||||||
|
# Handle both string and dict formats for captions
|
||||||
|
if isinstance(test_captions[i], dict):
|
||||||
|
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||||
|
else:
|
||||||
|
caption_text = str(test_captions[i])
|
||||||
|
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||||
|
print(f" Response {i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||||
|
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
576
benchmarks/laion/setup_laion.py
Normal file
576
benchmarks/laion/setup_laion.py
Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Setup Script
|
||||||
|
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann import LeannBuilder
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.images_dir = self.data_dir / "laion_images"
|
||||||
|
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||||
|
|
||||||
|
# Create directories
|
||||||
|
self.data_dir.mkdir(exist_ok=True)
|
||||||
|
self.images_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||||
|
"""Download a single image asynchronously"""
|
||||||
|
async with semaphore: # Limit concurrent downloads
|
||||||
|
try:
|
||||||
|
image_url = sample_data["url"]
|
||||||
|
image_path = sample_data["image_path"]
|
||||||
|
|
||||||
|
# Skip if already exists
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
|
||||||
|
async with session.get(image_url, timeout=10) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# Verify it's a valid image
|
||||||
|
try:
|
||||||
|
img = Image.open(io.BytesIO(content))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.save(image_path, "JPEG")
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None # Skip invalid images
|
||||||
|
else:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def download_laion_subset(self, num_samples: int = 1000):
|
||||||
|
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||||
|
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||||
|
|
||||||
|
# Load LAION-400M subset from HuggingFace
|
||||||
|
print("🤗 Loading from HuggingFace datasets...")
|
||||||
|
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||||
|
|
||||||
|
# Collect sample metadata first (fast)
|
||||||
|
print("📋 Collecting sample metadata...")
|
||||||
|
candidates = []
|
||||||
|
for sample in dataset:
|
||||||
|
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||||
|
break
|
||||||
|
|
||||||
|
image_url = sample.get("url", "")
|
||||||
|
caption = sample.get("caption", "")
|
||||||
|
|
||||||
|
if not image_url or not caption:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||||
|
image_path = self.images_dir / image_filename
|
||||||
|
|
||||||
|
candidate = {
|
||||||
|
"id": f"laion_{len(candidates):06d}",
|
||||||
|
"url": image_url,
|
||||||
|
"caption": caption,
|
||||||
|
"image_path": str(image_path),
|
||||||
|
"width": sample.get("original_width", 512),
|
||||||
|
"height": sample.get("original_height", 512),
|
||||||
|
"similarity": sample.get("similarity", 0.0),
|
||||||
|
}
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download images in parallel
|
||||||
|
async def download_batch():
|
||||||
|
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||||
|
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
|
|
||||||
|
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||||
|
tasks = []
|
||||||
|
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||||
|
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all downloads
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
# Filter successful downloads
|
||||||
|
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||||
|
return successful[:num_samples]
|
||||||
|
|
||||||
|
# Run async download
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
samples = loop.run_until_complete(download_batch())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Save metadata
|
||||||
|
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in samples:
|
||||||
|
f.write(json.dumps(sample) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||||
|
"""Generate CLIP image embeddings for downloaded images"""
|
||||||
|
print("🔍 Generating CLIP image embeddings...")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||||
|
# This single model can encode both images and text.
|
||||||
|
model = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
valid_samples = []
|
||||||
|
|
||||||
|
for sample in tqdm(samples, desc="Processing images"):
|
||||||
|
try:
|
||||||
|
# Load image
|
||||||
|
image_path = sample["image_path"]
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||||
|
vec = model.encode(
|
||||||
|
[image],
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=1,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)[0]
|
||||||
|
embeddings.append(vec.astype(np.float32))
|
||||||
|
valid_samples.append(sample)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||||
|
# Skip invalid images
|
||||||
|
|
||||||
|
embeddings = np.array(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Save embeddings
|
||||||
|
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(embeddings_file, embeddings)
|
||||||
|
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
return embeddings, valid_samples
|
||||||
|
|
||||||
|
def build_faiss_baseline(
|
||||||
|
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||||
|
):
|
||||||
|
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Extract image IDs (must be present)
|
||||||
|
if not samples or "id" not in samples[0]:
|
||||||
|
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||||
|
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
print(f"📄 Processing {len(image_ids)} images")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(image_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def create_leann_passages(self, samples: list[dict]):
|
||||||
|
"""Create LEANN-compatible passages from LAION data"""
|
||||||
|
print("📝 Creating LEANN passages...")
|
||||||
|
|
||||||
|
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||||
|
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
passage = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"text": sample["caption"], # Use caption as searchable text
|
||||||
|
"metadata": {
|
||||||
|
"image_url": sample["url"],
|
||||||
|
"image_path": sample.get("image_path", ""),
|
||||||
|
"width": sample["width"],
|
||||||
|
"height": sample["height"],
|
||||||
|
"similarity": sample["similarity"],
|
||||||
|
"image_index": i, # Index for embedding lookup
|
||||||
|
},
|
||||||
|
}
|
||||||
|
f.write(json.dumps(passage) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(samples)} passages")
|
||||||
|
return passages_file
|
||||||
|
|
||||||
|
def build_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||||
|
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
|
||||||
|
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - compact with recompute
|
||||||
|
# Note: For multimodal case, we need to handle embeddings differently
|
||||||
|
# Let's try using sentence-transformers mode but with custom embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
# HNSW params (or forwarded to chosen backend)
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
# Compact/pruned with recompute at query time
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages (text + metadata)
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def build_non_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||||
|
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure embeddings are saved (npy + pickle)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
if not npy_path.exists():
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
# Prepare ids in same order as passages_file
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
if not pkl_path.exists():
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - non-compact without recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=False, # Store embeddings (no recompute needed)
|
||||||
|
is_compact=False, # Store full index (not pruned)
|
||||||
|
distance_metric="cosine",
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages - embeddings will be loaded from file
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||||
|
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Adding passages"):
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
|
||||||
|
# Add image metadata - LEANN will handle embeddings separately
|
||||||
|
# Note: We store image metadata and caption text for searchability
|
||||||
|
# Important: ensure passage ID in metadata matches vector ID
|
||||||
|
builder.add_text(
|
||||||
|
text=passage["text"], # Image caption for searchability
|
||||||
|
metadata={**passage["metadata"], "id": passage["id"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_index_size(self, index_path: str):
|
||||||
|
"""Analyze index file sizes"""
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
index_path = Path(index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.name # e.g., laion_index.leann
|
||||||
|
index_prefix = index_path.stem # e.g., laion_index
|
||||||
|
|
||||||
|
files = [
|
||||||
|
(f"{index_prefix}.index", ".index", "core"),
|
||||||
|
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||||
|
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||||
|
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||||
|
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _fmt_size(bytes_val: int) -> str:
|
||||||
|
if bytes_val < 1024:
|
||||||
|
return f"{bytes_val} B"
|
||||||
|
kb = bytes_val / 1024
|
||||||
|
if kb < 1024:
|
||||||
|
return f"{kb:.1f} KB"
|
||||||
|
mb = kb / 1024
|
||||||
|
if mb < 1024:
|
||||||
|
return f"{mb:.2f} MB"
|
||||||
|
gb = mb / 1024
|
||||||
|
return f"{gb:.2f} GB"
|
||||||
|
|
||||||
|
total_index_only_mb = 0.0
|
||||||
|
total_all_mb = 0.0
|
||||||
|
for filename, label, group in files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_bytes = file_path.stat().st_size
|
||||||
|
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||||
|
size_mb = size_bytes / (1024 * 1024)
|
||||||
|
total_all_mb += size_mb
|
||||||
|
if group == "core":
|
||||||
|
total_index_only_mb += size_mb
|
||||||
|
else:
|
||||||
|
print(f" {label}: (missing)")
|
||||||
|
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||||
|
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||||
|
|
||||||
|
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||||
|
"""Create evaluation queries from captions"""
|
||||||
|
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||||
|
|
||||||
|
# Sample random captions as queries
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(42) # For reproducibility
|
||||||
|
|
||||||
|
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||||
|
|
||||||
|
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
with open(queries_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in query_samples:
|
||||||
|
query = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"query": sample["caption"],
|
||||||
|
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||||
|
}
|
||||||
|
f.write(json.dumps(query) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||||
|
return queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||||
|
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||||
|
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize setup
|
||||||
|
setup = LAIONSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Step 1: Download LAION subset
|
||||||
|
if not args.skip_download:
|
||||||
|
print("\n📦 Step 1: Download LAION subset")
|
||||||
|
samples = setup.download_laion_subset(args.num_samples)
|
||||||
|
|
||||||
|
# Step 2: Generate CLIP image embeddings
|
||||||
|
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||||
|
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||||
|
|
||||||
|
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||||
|
print("\n📝 Step 3: Create LEANN passages")
|
||||||
|
passages_file = setup.create_leann_passages(valid_samples)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LAION dataset download")
|
||||||
|
# Load existing data
|
||||||
|
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||||
|
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not embeddings_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passages or embeddings file not found. Run without --skip-download first."
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||||
|
|
||||||
|
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||||
|
if not args.skip_build:
|
||||||
|
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||||
|
|
||||||
|
# Build compact index (production mode - small, recompute required)
|
||||||
|
compact_index_path = args.index_path
|
||||||
|
print(f"Building compact index: {compact_index_path}")
|
||||||
|
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||||
|
|
||||||
|
# Build non-compact index (comparison mode - large, fast search)
|
||||||
|
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||||
|
print(f"Building non-compact index: {non_compact_index_path}")
|
||||||
|
setup.build_non_compact_index(
|
||||||
|
passages_file, embeddings, non_compact_index_path, args.backend
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||||
|
if not args.skip_download:
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
else:
|
||||||
|
# Load valid_samples from passages file for FAISS baseline
|
||||||
|
valid_samples = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
|
||||||
|
# Step 6: Create evaluation queries
|
||||||
|
print("\n📝 Step 6: Create evaluation queries")
|
||||||
|
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
baseline_path = "data/baseline/faiss_index.bin"
|
||||||
|
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||||
|
|
||||||
|
print("\n🎉 Setup completed successfully!")
|
||||||
|
print("📊 Summary:")
|
||||||
|
if not args.skip_download:
|
||||||
|
print(f" Downloaded samples: {len(samples)}")
|
||||||
|
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||||
|
else:
|
||||||
|
print(f" Loaded {len(embeddings)} embeddings")
|
||||||
|
|
||||||
|
if not args.skip_build:
|
||||||
|
print(f" Compact index: {compact_index_path}")
|
||||||
|
print(f" Non-compact index: {non_compact_index_path}")
|
||||||
|
print(f" FAISS baseline: {baseline_path}")
|
||||||
|
print(f" Queries: {queries_file}")
|
||||||
|
|
||||||
|
print("\n🔧 Next steps:")
|
||||||
|
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||||
|
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print(" Skipped building indexes")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
301
benchmarks/llm_utils.py
Normal file
301
benchmarks/llm_utils.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
"""
|
||||||
|
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
HF_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
HF_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen3_model(model_name):
|
||||||
|
"""Check if model is Qwen3"""
|
||||||
|
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen_vl_model(model_name):
|
||||||
|
"""Check if model is Qwen2.5-VL"""
|
||||||
|
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||||
|
"""Apply Qwen3 chat template with thinking enabled"""
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_thinking_answer(response):
|
||||||
|
"""Extract final answer from Qwen3 thinking model response"""
|
||||||
|
if "<think>" in response and "</think>" in response:
|
||||||
|
try:
|
||||||
|
think_end = response.index("</think>") + len("</think>")
|
||||||
|
final_answer = response[think_end:].strip()
|
||||||
|
return final_answer
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_model(model_name="Qwen/Qwen3-8B"):
|
||||||
|
"""Load HuggingFace model"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
print(f"Loading HF: {model_name}")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
def load_vllm_model(model_name="Qwen/Qwen3-8B"):
|
||||||
|
"""Load vLLM model"""
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm not available")
|
||||||
|
|
||||||
|
print(f"Loading vLLM: {model_name}")
|
||||||
|
llm = LLM(model=model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
# Qwen3 specific config
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||||
|
max_tokens = 2048
|
||||||
|
else:
|
||||||
|
stop_tokens = None
|
||||||
|
max_tokens = 1024
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||||
|
return llm, sampling_params
|
||||||
|
|
||||||
|
|
||||||
|
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||||
|
"""Generate with HF - supports Qwen3 thinking models"""
|
||||||
|
model_name = getattr(model, "name_or_path", "unknown")
|
||||||
|
is_qwen3 = is_qwen3_model(model_name)
|
||||||
|
|
||||||
|
# Apply chat template for Qwen3
|
||||||
|
if is_qwen3:
|
||||||
|
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||||
|
max_tokens = max_tokens or 2048
|
||||||
|
else:
|
||||||
|
max_tokens = max_tokens or 1024
|
||||||
|
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
temperature=0.7,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
response = response[len(prompt) :].strip()
|
||||||
|
|
||||||
|
# Extract final answer for thinking models
|
||||||
|
if is_qwen3:
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vllm(llm, sampling_params, prompt):
|
||||||
|
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||||
|
outputs = llm.generate([prompt], sampling_params)
|
||||||
|
response = outputs[0].outputs[0].text.strip()
|
||||||
|
|
||||||
|
# Extract final answer for Qwen3 thinking models
|
||||||
|
model_name = str(llm.llm_engine.model_config.model)
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_prompt(context, query, domain="default"):
|
||||||
|
"""Create RAG prompt"""
|
||||||
|
if domain == "emails":
|
||||||
|
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
elif domain == "finance":
|
||||||
|
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
elif domain == "multimodal":
|
||||||
|
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
else:
|
||||||
|
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||||
|
"""Simple RAG evaluation with timing"""
|
||||||
|
search_times = []
|
||||||
|
gen_times = []
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Search
|
||||||
|
start = time.time()
|
||||||
|
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||||
|
search_time = time.time() - start
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
context = "\n\n".join([doc.text for doc in docs])
|
||||||
|
prompt = create_prompt(context, query, domain)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
response = llm_func(prompt)
|
||||||
|
gen_time = time.time() - start
|
||||||
|
|
||||||
|
search_times.append(search_time)
|
||||||
|
gen_times.append(gen_time)
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times),
|
||||||
|
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
|
||||||
|
"""Load Qwen2.5-VL multimodal model"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained(
|
||||||
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||||
|
|
||||||
|
# Fallback to specific class
|
||||||
|
try:
|
||||||
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e2:
|
||||||
|
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||||
|
"""Generate with Qwen2.5-VL multimodal model"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
if image_path:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||||
|
else:
|
||||||
|
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode response
|
||||||
|
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||||
|
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||||
|
"""Create prompt for multimodal RAG"""
|
||||||
|
if task_type == "images":
|
||||||
|
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||||
|
|
||||||
|
Retrieved Image Descriptions:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
|
||||||
|
Provide a detailed answer based on the visual content described above."""
|
||||||
|
|
||||||
|
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||||
|
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||||
|
search_times = []
|
||||||
|
gen_times = []
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, query_item in enumerate(queries):
|
||||||
|
# Handle both string and dict formats for queries
|
||||||
|
if isinstance(query_item, dict):
|
||||||
|
query = query_item.get("query", "")
|
||||||
|
image_path = query_item.get("image_path") # Optional reference image
|
||||||
|
else:
|
||||||
|
query = str(query_item)
|
||||||
|
image_path = None
|
||||||
|
|
||||||
|
# Search
|
||||||
|
start_time = time.time()
|
||||||
|
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
|
||||||
|
# Prepare context from search results
|
||||||
|
context_parts = []
|
||||||
|
for result in search_results:
|
||||||
|
context_parts.append(f"- {result.text}")
|
||||||
|
context = "\n".join(context_parts)
|
||||||
|
|
||||||
|
# Generate with multimodal model
|
||||||
|
start_time = time.time()
|
||||||
|
if processor and model:
|
||||||
|
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||||
|
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||||
|
else:
|
||||||
|
response = f"Context: {context}"
|
||||||
|
gen_time = time.time() - start_time
|
||||||
|
|
||||||
|
gen_times.append(gen_time)
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times),
|
||||||
|
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
@@ -2,20 +2,20 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import contextmanager
|
from transformers import AutoModel, BitsAndBytesConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str
|
model_path: str
|
||||||
batch_sizes: List[int]
|
batch_sizes: list[int]
|
||||||
seq_length: int
|
seq_length: int
|
||||||
num_runs: int
|
num_runs: int
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -32,13 +32,11 @@ class GraphContainer:
|
|||||||
def __init__(self, model: nn.Module, seq_length: int):
|
def __init__(self, model: nn.Module, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
self.graphs: dict[int, GraphWrapper] = {}
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||||
if batch_size not in self.graphs:
|
if batch_size not in self.graphs:
|
||||||
self.graphs[batch_size] = GraphWrapper(
|
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
return self.graphs[batch_size]
|
||||||
|
|
||||||
|
|
||||||
@@ -55,13 +53,13 @@ class GraphWrapper:
|
|||||||
self._warmup()
|
self._warmup()
|
||||||
|
|
||||||
# Only use CUDA graphs on NVIDIA GPUs
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||||
# Capture graph
|
# Capture graph
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(self.graph):
|
with torch.cuda.graph(self.graph):
|
||||||
self.static_output = self.model(
|
self.static_output = self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
self.use_cuda_graph = True
|
self.use_cuda_graph = True
|
||||||
else:
|
else:
|
||||||
@@ -79,9 +77,7 @@ class GraphWrapper:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000, (batch_size, seq_length),
|
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||||
device=self.device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
def _warmup(self, num_warmup: int = 3):
|
||||||
@@ -89,7 +85,7 @@ class GraphWrapper:
|
|||||||
for _ in range(num_warmup):
|
for _ in range(num_warmup):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -133,8 +129,12 @@ class ModelOptimizer:
|
|||||||
print("- Using FP16 precision")
|
print("- Using FP16 precision")
|
||||||
|
|
||||||
# Check if using SDPA (only on CUDA)
|
# Check if using SDPA (only on CUDA)
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -142,7 +142,8 @@ class ModelOptimizer:
|
|||||||
# Flash Attention (only on CUDA)
|
# Flash Attention (only on CUDA)
|
||||||
if config.use_flash_attention and torch.cuda.is_available():
|
if config.use_flash_attention and torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||||
|
|
||||||
print("- Flash Attention 2 available")
|
print("- Flash Attention 2 available")
|
||||||
if hasattr(model.config, "attention_mode"):
|
if hasattr(model.config, "attention_mode"):
|
||||||
model.config.attention_mode = "flash_attention_2"
|
model.config.attention_mode = "flash_attention_2"
|
||||||
@@ -153,8 +154,9 @@ class ModelOptimizer:
|
|||||||
# Memory efficient attention (only on CUDA)
|
# Memory efficient attention (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -220,7 +222,7 @@ class Benchmark:
|
|||||||
self.graphs = None
|
self.graphs = None
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
print(f"ERROR in benchmark initialization: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
@@ -230,15 +232,17 @@ class Benchmark:
|
|||||||
# Int4 quantization using HuggingFace integration
|
# Int4 quantization using HuggingFace integration
|
||||||
if self.config.use_int4:
|
if self.config.use_int4:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
# Check if using custom 8bit quantization
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
# Load original model (without quantization config)
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# set default to half
|
# set default to half
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
@@ -247,52 +251,58 @@ class Benchmark:
|
|||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义替换函数
|
# Define replacement function
|
||||||
def replace_linear_with_linear8bitlt(model):
|
def replace_linear_with_linear8bitlt(model):
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||||
for name, module in list(model.named_children()):
|
for name, module in list(model.named_children()):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# 获取原始线性层的参数
|
# Get original linear layer parameters
|
||||||
in_features = module.in_features
|
in_features = module.in_features
|
||||||
out_features = module.out_features
|
out_features = module.out_features
|
||||||
bias = module.bias is not None
|
bias = module.bias is not None
|
||||||
|
|
||||||
# 创建8bit线性层
|
# Create 8bit linear layer
|
||||||
# print size
|
# print size
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
new_module = bnb.nn.Linear8bitLt(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
has_fp16_weights=False
|
has_fp16_weights=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 复制权重和偏置
|
# Copy weights and bias
|
||||||
new_module.weight.data = module.weight.data
|
new_module.weight.data = module.weight.data
|
||||||
if bias:
|
if bias:
|
||||||
new_module.bias.data = module.bias.data
|
new_module.bias.data = module.bias.data
|
||||||
|
|
||||||
# 替换模块
|
# Replace module
|
||||||
setattr(model, name, new_module)
|
setattr(model, name, new_module)
|
||||||
else:
|
else:
|
||||||
# 递归处理子模块
|
# Process child modules recursively
|
||||||
replace_linear_with_linear8bitlt(module)
|
replace_linear_with_linear8bitlt(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 替换所有线性层
|
# Replace all linear layers
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
model = replace_linear_with_linear8bitlt(model)
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
# Move model to GPU (quantization happens here)
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
print("- All linear layers replaced with Linear8bitLt")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用原来的Int4量化方法
|
# Use original Int4 quantization method
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
print("- Using bitsandbytes for Int4 quantization")
|
||||||
|
|
||||||
# Create quantization config
|
# Create quantization config
|
||||||
@@ -302,7 +312,7 @@ class Benchmark:
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4"
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
print("- Quantization config:", quantization_config)
|
||||||
@@ -312,7 +322,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto" # Let HF decide on device mapping
|
device_map="auto", # Let HF decide on device mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if model loaded successfully
|
# Check if model loaded successfully
|
||||||
@@ -324,7 +334,7 @@ class Benchmark:
|
|||||||
# Apply optimizations directly here
|
# Apply optimizations directly here
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||||
else:
|
else:
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
# Skip moving to GPU since device_map="auto" already did that
|
||||||
@@ -334,8 +344,12 @@ class Benchmark:
|
|||||||
print(f"- Using {compute_dtype} for compute dtype")
|
print(f"- Using {compute_dtype} for compute dtype")
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
# Check CUDA and SDPA
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -343,8 +357,7 @@ class Benchmark:
|
|||||||
# Try xformers if available (only on CUDA)
|
# Try xformers if available (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -370,7 +383,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto"
|
device_map="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -389,6 +402,7 @@ class Benchmark:
|
|||||||
# Apply standard optimizations
|
# Apply standard optimizations
|
||||||
# set default to half
|
# set default to half
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
model = ModelOptimizer.optimize(model, self.config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -403,25 +417,31 @@ class Benchmark:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR loading model: {str(e)}")
|
print(f"ERROR loading model: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(
|
def _run_inference(
|
||||||
self,
|
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||||
input_ids: torch.Tensor,
|
) -> tuple[float, torch.Tensor]:
|
||||||
graph_wrapper: Optional[GraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
with torch.no_grad(), self.timer.timing():
|
||||||
@@ -432,7 +452,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
return self.timer.elapsed_time(), output
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Reset peak memory stats
|
# Reset peak memory stats
|
||||||
@@ -450,9 +470,7 @@ class Benchmark:
|
|||||||
|
|
||||||
# Get or create graph for this batch size
|
# Get or create graph for this batch size
|
||||||
graph_wrapper = (
|
graph_wrapper = (
|
||||||
self.graphs.get_or_create(batch_size)
|
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||||
if self.graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
# Pre-allocate input tensor
|
||||||
@@ -604,7 +622,15 @@ def main():
|
|||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
|
|
||||||
# Generate filename based on configuration
|
# Generate filename based on configuration
|
||||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
precision_type = (
|
||||||
|
"int4"
|
||||||
|
if config.use_int4
|
||||||
|
else "int8"
|
||||||
|
if config.use_int8
|
||||||
|
else "fp16"
|
||||||
|
if config.use_fp16
|
||||||
|
else "fp32"
|
||||||
|
)
|
||||||
model_name = os.path.basename(config.model_path)
|
model_name = os.path.basename(config.model_path)
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||||
|
|
||||||
@@ -612,17 +638,20 @@ def main():
|
|||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
"config": {
|
||||||
"results": {str(k): v for k, v in results.items()}
|
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||||
|
},
|
||||||
|
"results": {str(k): v for k, v in results.items()},
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
indent=2
|
indent=2,
|
||||||
)
|
)
|
||||||
print(f"Results saved to {output_file}")
|
print(f"Results saved to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
|||||||
results and the golden standard results, making the comparison robust to ID changes.
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher, LeannBuilder
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
if not data_root.exists():
|
if not data_root.exists():
|
||||||
print(f"Data directory '{data_root}' not found.")
|
print(f"Data directory '{data_root}' not found.")
|
||||||
print(
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -56,14 +53,14 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
print(
|
print(
|
||||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
)
|
)
|
||||||
print("uv pip install -e '.[dev]'")
|
print("uv sync --only-group dev")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred during data download: {e}")
|
print(f"An error occurred during data download: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
"""Download embeddings files specifically."""
|
"""Download embeddings files specifically."""
|
||||||
embeddings_dir = data_root / "embeddings"
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
|||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
# --- Helper Function to get Golden Passages ---
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
"""
|
"""
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
passage manager.
|
passage manager.
|
||||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
|||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
golden_texts.add(passage_data["text"])
|
golden_texts.add(passage_data["text"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
|
||||||
)
|
|
||||||
return golden_texts
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> List[str]:
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
queries = []
|
queries = []
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
queries.append(data["query"])
|
queries.append(data["query"])
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Build a LEANN index from pre-computed embeddings.
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
description="Run recall evaluation on a LEANN index."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"index_path",
|
"index_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -202,26 +193,41 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
# Assumes a project structure where the script is in 'examples/'
|
# Assumes a project structure where the script is in 'benchmarks/'
|
||||||
# and data is in 'data/' at the project root.
|
# and evaluation data is in 'benchmarks/data/'.
|
||||||
project_root = Path(__file__).resolve().parent.parent
|
script_dir = Path(__file__).resolve().parent
|
||||||
data_root = project_root / "data"
|
data_root = script_dir / "data"
|
||||||
|
|
||||||
# Download data based on mode
|
# Download data based on mode
|
||||||
if args.mode == "build":
|
if args.mode == "build":
|
||||||
# For building mode, we need embeddings
|
# For building mode, we need embeddings
|
||||||
download_data_if_needed(
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
data_root, download_embeddings=False
|
|
||||||
) # Basic data first
|
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
# Auto-detect dataset type and download embeddings
|
||||||
if args.embeddings_file:
|
if args.embeddings_file:
|
||||||
@@ -262,9 +268,7 @@ def main():
|
|||||||
print(f"Index built successfully: {built_index_path}")
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
# Ask if user wants to run evaluation
|
||||||
eval_response = (
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
|
||||||
)
|
|
||||||
if eval_response != "y":
|
if eval_response != "y":
|
||||||
print("Index building complete. Exiting.")
|
print("Index building complete. Exiting.")
|
||||||
return
|
return
|
||||||
@@ -293,11 +297,9 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not args.index_path:
|
if not args.index_path:
|
||||||
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
print(
|
print(
|
||||||
"No indices found. The data download should have included pre-built indices."
|
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||||
)
|
|
||||||
print(
|
|
||||||
"Please check the data/indices/ directory or provide --index-path manually."
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@@ -310,14 +312,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# Fallback: try to infer from the index directory name
|
# Fallback: try to infer from the index directory name
|
||||||
dataset_type = Path(args.index_path).name
|
dataset_type = Path(args.index_path).name
|
||||||
print(
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
golden_results_file = (
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
@@ -327,7 +325,7 @@ def main():
|
|||||||
searcher = LeannSearcher(args.index_path)
|
searcher = LeannSearcher(args.index_path)
|
||||||
queries = load_queries(queries_file)
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
with open(golden_results_file, "r") as f:
|
with open(golden_results_file) as f:
|
||||||
golden_results_data = json.load(f)
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
@@ -340,10 +338,23 @@ def main():
|
|||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(
|
new_results = searcher.search(
|
||||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
)
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
@@ -1,26 +1,27 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
# Add MLX imports
|
# Add MLX imports
|
||||||
try:
|
try:
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx_lm.utils import load
|
from mlx_lm.utils import load
|
||||||
|
|
||||||
MLX_AVAILABLE = True
|
MLX_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
MLX_AVAILABLE = False
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: List[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -33,7 +34,8 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
"""MLX-specific benchmark for embedding models"""
|
"""MLX-specific benchmark for embedding models"""
|
||||||
@@ -55,11 +57,7 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int):
|
def _create_random_batch(self, batch_size: int):
|
||||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
return torch.randint(
|
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
"""Run MLX inference with same input as PyTorch"""
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
@@ -82,12 +80,12 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX inference error: {e}")
|
print(f"MLX inference error: {e}")
|
||||||
return float('inf')
|
return float("inf")
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
"""Run the MLX benchmark across all batch sizes"""
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
@@ -111,10 +109,10 @@ class MLXBenchmark:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Run benchmark
|
# Run benchmark
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
if elapsed_time != float('inf'):
|
if elapsed_time != float("inf"):
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during MLX inference: {e}")
|
print(f"Error during MLX inference: {e}")
|
||||||
@@ -145,16 +143,22 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, config: BenchmarkConfig):
|
def __init__(self, config: BenchmarkConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
self.device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
self.model = self._load_model()
|
self.model = self._load_model()
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
if self.config.use_fp16:
|
if self.config.use_fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -166,23 +170,30 @@ class Benchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -194,7 +205,7 @@ class Benchmark:
|
|||||||
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
@@ -228,6 +239,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark():
|
def run_benchmark():
|
||||||
"""Main function to run the benchmark with optimized parameters."""
|
"""Main function to run the benchmark with optimized parameters."""
|
||||||
config = BenchmarkConfig()
|
config = BenchmarkConfig()
|
||||||
@@ -242,16 +254,13 @@ def run_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_mlx_benchmark():
|
def run_mlx_benchmark():
|
||||||
"""Run MLX-specific benchmark"""
|
"""Run MLX-specific benchmark"""
|
||||||
@@ -260,13 +269,10 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "MLX not available"
|
"error": "MLX not available",
|
||||||
}
|
}
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
|
||||||
use_mlx=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
benchmark = MLXBenchmark(config)
|
benchmark = MLXBenchmark(config)
|
||||||
@@ -276,7 +282,7 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "No valid results"
|
"error": "No valid results",
|
||||||
}
|
}
|
||||||
|
|
||||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
@@ -285,16 +291,13 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX benchmark failed: {e}")
|
print(f"MLX benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("=== PyTorch Benchmark ===")
|
print("=== PyTorch Benchmark ===")
|
||||||
@@ -308,7 +311,7 @@ if __name__ == "__main__":
|
|||||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||||
print(f"\n=== Comparison ===")
|
print("\n=== Comparison ===")
|
||||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
82
data/.gitattributes
vendored
82
data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - uncompressed
|
|
||||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - compressed
|
|
||||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - uncompressed
|
|
||||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.png filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - compressed
|
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Video files - compressed
|
|
||||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
|
|||||||
including how to make donations to the Project Gutenberg Literary
|
including how to make donations to the Project Gutenberg Literary
|
||||||
Archive Foundation, how to help produce our new eBooks, and how to
|
Archive Foundation, how to help produce our new eBooks, and how to
|
||||||
subscribe to our email newsletter to hear about new eBooks.
|
subscribe to our email newsletter to hear about new eBooks.
|
||||||
|
|
||||||
|
|
||||||
105
demo.ipynb
105
demo.ipynb
@@ -1,37 +1,116 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Quick Start \n",
|
||||||
|
"\n",
|
||||||
|
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||||
|
"\n",
|
||||||
|
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
"# install this if you are using colab\n",
|
||||||
|
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||||
|
"! uv pip install leann --no-deps\n",
|
||||||
|
"# For Colab environment, we need to set some environment variables\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||||
|
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 1. Build the index (no embeddings stored!)\n",
|
|
||||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
"builder.add_text(\n",
|
||||||
|
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||||
|
")\n",
|
||||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
"builder.build_index(INDEX_PATH)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
|
||||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 3. Chat with LEANN using retrieved results\n",
|
|
||||||
"llm_config = {\n",
|
"llm_config = {\n",
|
||||||
" \"type\": \"ollama\",\n",
|
" \"type\": \"hf\",\n",
|
||||||
" \"model\": \"llama3.2:1b\"\n",
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
")"
|
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
223
docs/CONTRIBUTING.md
Normal file
223
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
## 🚀 Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Install uv** (fast Python package installer):
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Clone the repository**:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||||
|
cd LEANN-RAG
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install system dependencies**:
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ubuntu/Debian:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||||
|
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Build from source**:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔨 Pre-commit Hooks
|
||||||
|
|
||||||
|
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||||
|
|
||||||
|
### Setup Pre-commit
|
||||||
|
|
||||||
|
1. **Install pre-commit tools**:
|
||||||
|
```bash
|
||||||
|
uv sync lint
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install the git hooks**:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run pre-commit manually** (optional):
|
||||||
|
```bash
|
||||||
|
uv run pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-commit Checks
|
||||||
|
|
||||||
|
Our pre-commit configuration includes:
|
||||||
|
- **Trailing whitespace removal**
|
||||||
|
- **End-of-file fixing**
|
||||||
|
- **YAML validation**
|
||||||
|
- **Large file prevention**
|
||||||
|
- **Merge conflict detection**
|
||||||
|
- **Debug statement detection**
|
||||||
|
- **Code formatting with ruff**
|
||||||
|
- **Code linting with ruff**
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install test tools only (no project runtime)
|
||||||
|
uv sync --group test
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
uv run pytest test/test_filename.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
uv run pytest --cov=leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Place tests in the `test/` directory
|
||||||
|
- Follow the naming convention `test_*.py`
|
||||||
|
- Use descriptive test names that explain what's being tested
|
||||||
|
- Include both positive and negative test cases
|
||||||
|
|
||||||
|
## 📝 Code Style
|
||||||
|
|
||||||
|
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||||
|
|
||||||
|
### Format Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format all files
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
ruff format --check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run linter with auto-fix
|
||||||
|
ruff check --fix
|
||||||
|
|
||||||
|
# Just check without fixing
|
||||||
|
ruff check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Style Guidelines
|
||||||
|
|
||||||
|
- Follow PEP 8 conventions
|
||||||
|
- Use descriptive variable names
|
||||||
|
- Add type hints where appropriate
|
||||||
|
- Write docstrings for all public functions and classes
|
||||||
|
- Keep functions focused and single-purpose
|
||||||
|
|
||||||
|
## 🚦 CI/CD
|
||||||
|
|
||||||
|
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||||
|
|
||||||
|
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||||
|
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||||
|
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||||
|
4. **Wheel building**: Ensures packages can be built and distributed
|
||||||
|
|
||||||
|
### CI Commands
|
||||||
|
|
||||||
|
The CI uses the same commands as pre-commit to ensure consistency:
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format checking
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure your code passes these checks locally before pushing!
|
||||||
|
|
||||||
|
## 🔄 Pull Request Process
|
||||||
|
|
||||||
|
1. **Fork the repository** and create your branch from `main`:
|
||||||
|
```bash
|
||||||
|
git checkout -b feature/your-feature-name
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Make your changes**:
|
||||||
|
- Write clean, documented code
|
||||||
|
- Add tests for new functionality
|
||||||
|
- Update documentation as needed
|
||||||
|
|
||||||
|
3. **Run pre-commit checks**:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Test your changes**:
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Commit with descriptive messages**:
|
||||||
|
```bash
|
||||||
|
git commit -m "feat: add new search algorithm"
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||||
|
- `feat:` for new features
|
||||||
|
- `fix:` for bug fixes
|
||||||
|
- `docs:` for documentation changes
|
||||||
|
- `test:` for test additions/changes
|
||||||
|
- `refactor:` for code refactoring
|
||||||
|
- `perf:` for performance improvements
|
||||||
|
|
||||||
|
6. **Push and create a pull request**:
|
||||||
|
- Provide a clear description of your changes
|
||||||
|
- Reference any related issues
|
||||||
|
- Include examples or screenshots if applicable
|
||||||
|
|
||||||
|
## 📚 Documentation
|
||||||
|
|
||||||
|
When adding new features or making significant changes:
|
||||||
|
|
||||||
|
1. Update relevant documentation in `/docs`
|
||||||
|
2. Add docstrings to new functions/classes
|
||||||
|
3. Update README.md if needed
|
||||||
|
4. Include usage examples
|
||||||
|
|
||||||
|
## 🤔 Getting Help
|
||||||
|
|
||||||
|
- **Discord**: Join our community for discussions
|
||||||
|
- **Issues**: Check existing issues or create a new one
|
||||||
|
- **Discussions**: For general questions and ideas
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Thinking Budget Feature Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||||
|
|
||||||
|
## Feature Description
|
||||||
|
|
||||||
|
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||||
|
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- **`medium`**: Balanced speed and reasoning depth
|
||||||
|
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### 1. Command Line Interface
|
||||||
|
|
||||||
|
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LEANN CLI
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# RAG Examples
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. LLM Backend Support
|
||||||
|
|
||||||
|
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
options = kwargs.copy()
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget:
|
||||||
|
options.pop("thinking_budget", None)
|
||||||
|
if thinking_budget in ["low", "medium", "high"]:
|
||||||
|
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||||
|
|
||||||
|
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||||
|
# Check if this is an o-series model
|
||||||
|
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||||
|
if any(model in self.model for model in o_series_models):
|
||||||
|
params["reasoning_effort"] = thinking_budget
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||||
|
|
||||||
|
### 3. Parameter Propagation
|
||||||
|
|
||||||
|
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||||
|
|
||||||
|
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||||
|
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||||
|
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||||
|
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||||
|
|
||||||
|
## Files Modified
|
||||||
|
|
||||||
|
### Core Implementation
|
||||||
|
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||||
|
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||||
|
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- `README.md`: Added thinking budget parameter to usage examples
|
||||||
|
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
```bash
|
||||||
|
# High reasoning effort for complex questions
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Medium reasoning for balanced performance
|
||||||
|
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||||
|
|
||||||
|
# Low reasoning for fast responses
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||||
|
```
|
||||||
|
|
||||||
|
### RAG Examples
|
||||||
|
```bash
|
||||||
|
# Email RAG with high reasoning
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Document RAG with medium reasoning
|
||||||
|
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
### Ollama Models
|
||||||
|
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||||
|
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||||
|
|
||||||
|
### OpenAI Models
|
||||||
|
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||||
|
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The implementation includes comprehensive testing:
|
||||||
|
- Parameter handling verification
|
||||||
|
- Backend-specific API format validation
|
||||||
|
- CLI argument parsing tests
|
||||||
|
- Integration with existing LEANN architecture
|
||||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# AST-Aware Code chunking guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable AST chunking for mixed content (code + docs)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||||
|
|
||||||
|
# Specialized code repository indexing
|
||||||
|
python -m apps.code_rag --repo-dir ./my_codebase
|
||||||
|
|
||||||
|
# Global CLI with AST support
|
||||||
|
leann build my-code-index --docs ./src --use-ast-chunking
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN with AST chunking support
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### When to Use AST Chunking
|
||||||
|
|
||||||
|
✅ **Recommended for:**
|
||||||
|
- Code repositories with multiple languages
|
||||||
|
- Mixed documentation and code content
|
||||||
|
- Complex codebases with deep function/class hierarchies
|
||||||
|
- When working with Claude Code for code assistance
|
||||||
|
|
||||||
|
❌ **Not recommended for:**
|
||||||
|
- Pure text documents
|
||||||
|
- Very large files (>1MB)
|
||||||
|
- Languages not supported by tree-sitter
|
||||||
|
|
||||||
|
### Optimal Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended settings for most codebases
|
||||||
|
python -m apps.code_rag \
|
||||||
|
--repo-dir ./src \
|
||||||
|
--ast-chunk-size 768 \
|
||||||
|
--ast-chunk-overlap 96 \
|
||||||
|
--exclude-dirs .git __pycache__ node_modules build dist
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Languages
|
||||||
|
|
||||||
|
| Extension | Language | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| `.py` | Python | ✅ Full support |
|
||||||
|
| `.java` | Java | ✅ Full support |
|
||||||
|
| `.cs` | C# | ✅ Full support |
|
||||||
|
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||||
|
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||||
|
|
||||||
|
## Integration Examples
|
||||||
|
|
||||||
|
### Document RAG with Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Enable code chunking in document RAG
|
||||||
|
python -m apps.document_rag \
|
||||||
|
--enable-code-chunking \
|
||||||
|
--data-dir ./project \
|
||||||
|
--query "How does authentication work in the codebase?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||||
|
- Code completion and suggestions
|
||||||
|
- Bug analysis and debugging
|
||||||
|
- Architecture understanding
|
||||||
|
- Refactoring assistance
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Fallback to Traditional Chunking**
|
||||||
|
- Normal behavior for unsupported languages
|
||||||
|
- Check logs for specific language support
|
||||||
|
|
||||||
|
2. **Performance with Large Files**
|
||||||
|
- Adjust `--max-file-size` parameter
|
||||||
|
- Use `--exclude-dirs` to skip unnecessary directories
|
||||||
|
|
||||||
|
3. **Quality Issues**
|
||||||
|
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||||
|
- Adjust overlap for better context preservation
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LEANN_LOG_LEVEL=DEBUG
|
||||||
|
python -m apps.code_rag --repo-dir ./my_code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Traditional Chunking
|
||||||
|
|
||||||
|
Existing workflows continue to work without changes. To enable AST chunking:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Before
|
||||||
|
python -m apps.document_rag --chunk-size 256
|
||||||
|
|
||||||
|
# After (maintains traditional chunking for non-code files)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||||
|
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||||
|
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Comparison between Sentence Transformers and OpenAI embeddings
|
||||||
|
|
||||||
|
This example shows how different embedding models handle complex queries
|
||||||
|
and demonstrates the differences between local and API-based embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
# OpenAI API key should be set as environment variable
|
||||||
|
# export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||||
|
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||||
|
|
||||||
|
# Two queries with same intent but different wording
|
||||||
|
query1 = "Tell me my browser history about some conference i often visit"
|
||||||
|
query2 = "browser history about conference I often visit"
|
||||||
|
|
||||||
|
texts = [query1, query2, conference_text, browser_text]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) # Already normalized
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_embeddings(embeddings, model_name):
|
||||||
|
print(f"\n=== {model_name} Results ===")
|
||||||
|
|
||||||
|
# Results for Query 1
|
||||||
|
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||||
|
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||||
|
|
||||||
|
print(f"Query 1: '{query1}'")
|
||||||
|
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Results for Query 2
|
||||||
|
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||||
|
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||||
|
|
||||||
|
print(f"\nQuery 2: '{query2}'")
|
||||||
|
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Show the impact
|
||||||
|
print("\n=== Impact Analysis ===")
|
||||||
|
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||||
|
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||||
|
|
||||||
|
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||||
|
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||||
|
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||||
|
print("✅ STABLE: Conference remains winner in both queries")
|
||||||
|
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||||
|
print("✅ STABLE: Browser remains winner in both queries")
|
||||||
|
else:
|
||||||
|
print("🔄 MIXED: Results vary between queries")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query1_conf": sim1_conf,
|
||||||
|
"query1_browser": sim1_browser,
|
||||||
|
"query2_conf": sim2_conf,
|
||||||
|
"query2_browser": sim2_browser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test Sentence Transformers
|
||||||
|
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||||
|
try:
|
||||||
|
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||||
|
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Sentence Transformers failed: {e}")
|
||||||
|
st_results = None
|
||||||
|
|
||||||
|
# Test OpenAI
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing OpenAI (text-embedding-3-small)...")
|
||||||
|
try:
|
||||||
|
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||||
|
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ OpenAI failed: {e}")
|
||||||
|
openai_results = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if st_results and openai_results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("=== COMPARISON SUMMARY ===")
|
||||||
459
docs/configuration-guide.md
Normal file
459
docs/configuration-guide.md
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
# LEANN Configuration Guide
|
||||||
|
|
||||||
|
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||||
|
|
||||||
|
## Getting Started: Simple is Better
|
||||||
|
|
||||||
|
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||||
|
|
||||||
|
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||||
|
```bash
|
||||||
|
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||||
|
```
|
||||||
|
|
||||||
|
**For other data sources**: Limit the dataset size for quick testing
|
||||||
|
```bash
|
||||||
|
# WeChat: Test with recent messages only
|
||||||
|
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||||
|
|
||||||
|
# Browser history: Last few days
|
||||||
|
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||||
|
|
||||||
|
# Email: Recent inbox
|
||||||
|
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||||
|
```
|
||||||
|
|
||||||
|
Once validated, scale up gradually:
|
||||||
|
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||||
|
- This helps identify issues early before committing to long processing times
|
||||||
|
|
||||||
|
## Embedding Model Selection: Understanding the Trade-offs
|
||||||
|
|
||||||
|
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||||
|
|
||||||
|
### Small Models (< 100M parameters)
|
||||||
|
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||||
|
- **Pros**: Lightweight, fast for both indexing and inference
|
||||||
|
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||||
|
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||||
|
|
||||||
|
### Medium Models (100M-500M parameters)
|
||||||
|
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||||
|
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||||
|
- **Cons**: Requires more compute than small models
|
||||||
|
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||||
|
|
||||||
|
### Large Models (500M+ parameters)
|
||||||
|
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||||
|
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||||
|
- **Cons**: Slower inference, longer index build times
|
||||||
|
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||||
|
|
||||||
|
### Quick Start: Cloud and Local Embedding Options
|
||||||
|
|
||||||
|
**OpenAI Embeddings (Fastest Setup)**
|
||||||
|
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||||
|
```bash
|
||||||
|
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||||
|
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ollama Embeddings (Privacy-Focused)**
|
||||||
|
For local embeddings with complete privacy:
|
||||||
|
```bash
|
||||||
|
# First, pull an embedding model
|
||||||
|
ollama pull nomic-embed-text
|
||||||
|
|
||||||
|
# Use Ollama embeddings
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||||
|
|
||||||
|
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||||
|
- **Pros**: No local compute needed, consistently fast, high quality
|
||||||
|
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
|
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||||
|
|
||||||
|
**Local Embeddings**
|
||||||
|
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||||
|
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||||
|
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Local & Remote Inference Endpoints
|
||||||
|
|
||||||
|
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||||
|
|
||||||
|
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||||
|
|
||||||
|
### One-Time Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||||
|
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||||
|
|
||||||
|
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||||
|
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||||
|
|
||||||
|
### Passing Hosts Per Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index with a remote embedding server
|
||||||
|
leann build my-notes \
|
||||||
|
--docs ./notes \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||||
|
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||||
|
--embedding-api-key local-dev-key
|
||||||
|
|
||||||
|
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm openai \
|
||||||
|
--llm-model qwen3-8b \
|
||||||
|
--api-base http://localhost:1234/v1 \
|
||||||
|
--api-key local-dev-key
|
||||||
|
|
||||||
|
# Query an Ollama instance running on another box
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm ollama \
|
||||||
|
--llm-model qwen3:14b \
|
||||||
|
--host http://192.168.1.101:11434
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||||
|
|
||||||
|
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||||
|
- Configure router or cloud provider port forwarding.
|
||||||
|
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||||
|
|
||||||
|
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||||
|
|
||||||
|
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||||
|
|
||||||
|
### Python API Usage
|
||||||
|
|
||||||
|
You can pass the same configuration from Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://192.168.1.50:1234/v1",
|
||||||
|
"api_key": "local-dev-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-notes", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||||
|
- Full recomputation required
|
||||||
|
- High memory usage during build phase
|
||||||
|
- Excellent recall (95%+)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Optimal for most use cases
|
||||||
|
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### DiskANN
|
||||||
|
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||||
|
|
||||||
|
**Key advantages:**
|
||||||
|
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||||
|
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||||
|
- **Better scaling**: Designed for 100k+ documents
|
||||||
|
|
||||||
|
**Recompute behavior:**
|
||||||
|
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||||
|
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended for most use cases
|
||||||
|
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||||
|
```
|
||||||
|
|
||||||
|
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||||
|
|
||||||
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
|
### LLM Engines
|
||||||
|
|
||||||
|
**OpenAI** (`--llm openai`)
|
||||||
|
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||||
|
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||||
|
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||||
|
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||||
|
|
||||||
|
**Ollama** (`--llm ollama`)
|
||||||
|
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||||
|
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||||
|
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||||
|
|
||||||
|
**HuggingFace** (`--llm hf`)
|
||||||
|
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||||
|
- **Cons**: More complex initial setup
|
||||||
|
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||||
|
|
||||||
|
## Parameter Tuning Guide
|
||||||
|
|
||||||
|
### Search Complexity Parameters
|
||||||
|
|
||||||
|
**`--build-complexity`** (index building)
|
||||||
|
- Controls thoroughness during index construction
|
||||||
|
- Higher = better recall but slower build
|
||||||
|
- Recommendations:
|
||||||
|
- 32: Quick prototyping
|
||||||
|
- 64: Balanced (default)
|
||||||
|
- 128: Production systems
|
||||||
|
- 256: Maximum quality
|
||||||
|
|
||||||
|
**`--search-complexity`** (query time)
|
||||||
|
- Controls search thoroughness
|
||||||
|
- Higher = better results but slower
|
||||||
|
- Recommendations:
|
||||||
|
- 16: Fast/Interactive search
|
||||||
|
- 32: High quality with diversity
|
||||||
|
- 64+: Maximum accuracy
|
||||||
|
|
||||||
|
### Top-K Selection
|
||||||
|
|
||||||
|
**`--top-k`** (number of retrieved chunks)
|
||||||
|
- More chunks = better context but slower LLM processing
|
||||||
|
- Should be always smaller than `--search-complexity`
|
||||||
|
- Guidelines:
|
||||||
|
- 10-20: General questions (default: 20)
|
||||||
|
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||||
|
|
||||||
|
**Trade-off formula**:
|
||||||
|
- Retrieval time ∝ log(n) × search_complexity
|
||||||
|
- LLM processing time ∝ top_k × chunk_size
|
||||||
|
- Total context = top_k × chunk_size tokens
|
||||||
|
|
||||||
|
### Thinking Budget for Reasoning Models
|
||||||
|
|
||||||
|
**`--thinking-budget`** (reasoning effort level)
|
||||||
|
- Controls the computational effort for reasoning models
|
||||||
|
- Options: `low`, `medium`, `high`
|
||||||
|
- Guidelines:
|
||||||
|
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- `medium`: Balanced speed and reasoning depth
|
||||||
|
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
- **Supported Models**:
|
||||||
|
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||||
|
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||||
|
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||||
|
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||||
|
|
||||||
|
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||||
|
|
||||||
|
**💡 Quick Examples:**
|
||||||
|
```bash
|
||||||
|
# OpenAI o-series reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
|
||||||
|
# Ollama reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
```
|
||||||
|
|
||||||
|
### Graph Degree (HNSW/DiskANN)
|
||||||
|
|
||||||
|
**`--graph-degree`**
|
||||||
|
- Number of connections per node in the graph
|
||||||
|
- Higher = better recall but more memory
|
||||||
|
- HNSW: 16-32 (default: 32)
|
||||||
|
- DiskANN: 32-128 (default: 64)
|
||||||
|
|
||||||
|
|
||||||
|
## Performance Optimization Checklist
|
||||||
|
|
||||||
|
### If Embedding is Too Slow
|
||||||
|
|
||||||
|
1. **Switch to smaller model**:
|
||||||
|
```bash
|
||||||
|
# From large model
|
||||||
|
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||||
|
# To small model
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Limit dataset size for testing**:
|
||||||
|
```bash
|
||||||
|
--max-items 1000 # Process first 1k items only
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||||
|
```bash
|
||||||
|
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||||
|
```
|
||||||
|
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||||
|
|
||||||
|
4. **Use Ollama**
|
||||||
|
```bash
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||||
|
### If Search Quality is Poor
|
||||||
|
|
||||||
|
1. **Increase retrieval count**:
|
||||||
|
```bash
|
||||||
|
--top-k 30 # Retrieve more candidates
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Upgrade embedding model**:
|
||||||
|
```bash
|
||||||
|
# For English
|
||||||
|
--embedding-model BAAI/bge-base-en-v1.5
|
||||||
|
# For multilingual
|
||||||
|
--embedding-model intfloat/multilingual-e5-large
|
||||||
|
```
|
||||||
|
|
||||||
|
## Understanding the Trade-offs
|
||||||
|
|
||||||
|
Every configuration choice involves trade-offs:
|
||||||
|
|
||||||
|
| Factor | Small/Fast | Large/Quality |
|
||||||
|
|--------|------------|---------------|
|
||||||
|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||||
|
| Chunk Size | 512 tokens | 128 tokens |
|
||||||
|
| Index Type | HNSW | DiskANN |
|
||||||
|
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||||
|
|
||||||
|
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||||
|
|
||||||
|
## Low-resource setups
|
||||||
|
|
||||||
|
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||||
|
|
||||||
|
### 1) Use OpenAI embeddings (no local compute)
|
||||||
|
|
||||||
|
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY=sk-...
|
||||||
|
|
||||||
|
# Build with OpenAI embeddings
|
||||||
|
leann build my-index \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-3-small
|
||||||
|
|
||||||
|
# Search with OpenAI embeddings (recompute at query time)
|
||||||
|
leann search my-index "your query" \
|
||||||
|
--recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||||
|
|
||||||
|
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# One-time: install and configure SkyPilot
|
||||||
|
pip install skypilot
|
||||||
|
|
||||||
|
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml
|
||||||
|
|
||||||
|
# Override parameters via -e key=value (optional)
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||||
|
-e index_name=my-index \
|
||||||
|
-e backend=hnsw \
|
||||||
|
-e embedding_mode=sentence-transformers \
|
||||||
|
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||||
|
|
||||||
|
# Copy the built index back to your local .leann (use rsync)
|
||||||
|
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) Disable recomputation to trade storage for speed
|
||||||
|
|
||||||
|
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||||
|
leann build my-index --no-recompute --no-compact
|
||||||
|
|
||||||
|
# Search without recomputation
|
||||||
|
leann search my-index "your query" --no-recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
When to use:
|
||||||
|
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||||
|
- Read-heavy workloads where storage is cheaper than latency
|
||||||
|
- No always-available GPU
|
||||||
|
|
||||||
|
Constraints:
|
||||||
|
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||||
|
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||||
|
|
||||||
|
Storage impact:
|
||||||
|
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||||
|
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||||
|
|
||||||
|
Converting an existing index (rebuild required):
|
||||||
|
```bash
|
||||||
|
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
|
||||||
|
leann build my-index --force --no-recompute --no-compact
|
||||||
|
```
|
||||||
|
|
||||||
|
Python API usage:
|
||||||
|
```python
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("/path/to/my-index.leann")
|
||||||
|
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
Trade-offs:
|
||||||
|
- Lower latency and fewer network hops at query time
|
||||||
|
- Significantly higher storage (10–100× vs selective recomputation)
|
||||||
|
- Slightly larger memory footprint during build and search
|
||||||
|
|
||||||
|
Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):
|
||||||
|
|
||||||
|
- HNSW
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.818s, size=1.1MB
|
||||||
|
recompute=False: search_time=0.012s, size=16.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
- DiskANN
|
||||||
|
|
||||||
|
```text
|
||||||
|
recompute=True: search_time=0.041s, size=5.9MB
|
||||||
|
recompute=False: search_time=0.013s, size=24.6MB
|
||||||
|
```
|
||||||
|
|
||||||
|
Conclusion:
|
||||||
|
- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (stores all embeddings)
|
||||||
|
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, enables build-time partitioning for smaller storage)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Further Reading
|
||||||
|
|
||||||
|
- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
|
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
|
||||||
|
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
|
||||||
|
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)
|
||||||
10
docs/faq.md
Normal file
10
docs/faq.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# FAQ
|
||||||
|
|
||||||
|
## 1. My building time seems long
|
||||||
|
|
||||||
|
You can speed up the process by using a lightweight embedding model. Add this to your arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)
|
||||||
23
docs/features.md
Normal file
23
docs/features.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# ✨ Detailed Features
|
||||||
|
|
||||||
|
## 🔥 Core Features
|
||||||
|
|
||||||
|
- **🔄 Real-time Embeddings** - Eliminate heavy embedding storage with dynamic computation using optimized ZMQ servers and highly optimized search paradigm (overlapping and batching) with highly optimized embedding engine
|
||||||
|
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
|
||||||
|
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
|
||||||
|
- **🎯 Graph Pruning** - Advanced techniques to minimize the storage overhead of vector search to a limited footprint
|
||||||
|
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments
|
||||||
|
|
||||||
|
## 🛠️ Technical Highlights
|
||||||
|
- **🔄 Recompute Mode** - Highest accuracy scenarios while eliminating vector storage overhead
|
||||||
|
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
|
||||||
|
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
|
||||||
|
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
|
||||||
|
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
|
||||||
|
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))
|
||||||
|
|
||||||
|
## 🎨 Developer Experience
|
||||||
|
|
||||||
|
- **Simple Python API** - Get started in minutes
|
||||||
|
- **Extensible backend system** - Easy to add new algorithms
|
||||||
|
- **Comprehensive examples** - From basic usage to production deployment
|
||||||
149
docs/grep_search.md
Normal file
149
docs/grep_search.md
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# LEANN Grep Search Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Simple Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("your_index_path")
|
||||||
|
|
||||||
|
# Exact text search
|
||||||
|
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:100]}...")
|
||||||
|
print("-" * 40)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Comparison: Semantic vs Grep Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Semantic search - finds conceptually similar content
|
||||||
|
semantic_results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
|
||||||
|
# Grep search - finds exact text matches
|
||||||
|
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## When to Use Grep Search
|
||||||
|
|
||||||
|
### Use Cases
|
||||||
|
|
||||||
|
- **Code Search**: Finding specific function definitions, class names, or variable references
|
||||||
|
- **Error Debugging**: Locating exact error messages or stack traces
|
||||||
|
- **Documentation**: Finding specific API endpoints or exact terminology
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find function definitions
|
||||||
|
functions = searcher.search("def __init__", use_grep=True)
|
||||||
|
|
||||||
|
# Find import statements
|
||||||
|
imports = searcher.search("from sklearn import", use_grep=True)
|
||||||
|
|
||||||
|
# Find specific error types
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
|
||||||
|
# Find TODO comments
|
||||||
|
todos = searcher.search("TODO:", use_grep=True)
|
||||||
|
|
||||||
|
# Find configuration entries
|
||||||
|
configs = searcher.search("server_port=", use_grep=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Technical Details
|
||||||
|
|
||||||
|
### How It Works
|
||||||
|
|
||||||
|
1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
|
||||||
|
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
|
||||||
|
3. **Result Processing**: Parses JSON lines and extracts text and metadata
|
||||||
|
4. **Scoring**: Simple frequency-based scoring based on query term occurrences
|
||||||
|
|
||||||
|
### Search Process
|
||||||
|
|
||||||
|
```
|
||||||
|
Query: "def train_model"
|
||||||
|
↓
|
||||||
|
grep -i -n "def train_model" documents.leann.passages.jsonl
|
||||||
|
↓
|
||||||
|
Parse matching JSON lines
|
||||||
|
↓
|
||||||
|
Calculate scores based on term frequency
|
||||||
|
↓
|
||||||
|
Return top_k results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Scoring Algorithm
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Term frequency in document
|
||||||
|
score = text.lower().count(query.lower())
|
||||||
|
```
|
||||||
|
|
||||||
|
Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
#### Grep Command Not Found
|
||||||
|
```
|
||||||
|
RuntimeError: grep command not found. Please install grep or use semantic search.
|
||||||
|
```
|
||||||
|
|
||||||
|
**Solution**: Install grep on your system:
|
||||||
|
- **Ubuntu/Debian**: `sudo apt-get install grep`
|
||||||
|
- **macOS**: grep is pre-installed
|
||||||
|
- **Windows**: Use WSL or install grep via Git Bash/MSYS2
|
||||||
|
|
||||||
|
#### No Results Found
|
||||||
|
```python
|
||||||
|
# Check if your query exists in the raw data
|
||||||
|
results = searcher.search("your_query", use_grep=True)
|
||||||
|
if not results:
|
||||||
|
print("No exact matches found. Try:")
|
||||||
|
print("1. Check spelling and case")
|
||||||
|
print("2. Use partial terms")
|
||||||
|
print("3. Switch to semantic search")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example
|
||||||
|
|
||||||
|
```python
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
Demonstrates grep search for exact text matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
def demonstrate_grep_search():
|
||||||
|
# Initialize searcher
|
||||||
|
searcher = LeannSearcher("my_index")
|
||||||
|
|
||||||
|
print("=== Function Search ===")
|
||||||
|
functions = searcher.search("def __init__", use_grep=True, top_k=5)
|
||||||
|
for i, result in enumerate(functions, 1):
|
||||||
|
print(f"{i}. Score: {result.score}")
|
||||||
|
print(f" Preview: {result.text[:60]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=== Error Search ===")
|
||||||
|
errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
|
||||||
|
for result in errors:
|
||||||
|
print(f"Content: {result.text.strip()}")
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demonstrate_grep_search()
|
||||||
|
```
|
||||||
300
docs/metadata_filtering.md
Normal file
300
docs/metadata_filtering.md
Normal file
@@ -0,0 +1,300 @@
|
|||||||
|
# LEANN Metadata Filtering Usage Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Leann possesses metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date/type, code search by file type, and potentially much more.
|
||||||
|
|
||||||
|
## Basic Usage
|
||||||
|
|
||||||
|
### Adding Metadata to Your Documents
|
||||||
|
|
||||||
|
When building your index, add metadata to each text chunk:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
|
||||||
|
# Add text with metadata
|
||||||
|
builder.add_text(
|
||||||
|
text="Chapter 1: Alice falls down the rabbit hole",
|
||||||
|
metadata={
|
||||||
|
"chapter": 1,
|
||||||
|
"character": "Alice",
|
||||||
|
"themes": ["adventure", "curiosity"],
|
||||||
|
"word_count": 150
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.build_index("alice_in_wonderland_index")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Searching with Metadata Filters
|
||||||
|
|
||||||
|
Use the `metadata_filters` parameter in search calls:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannSearcher
|
||||||
|
|
||||||
|
searcher = LeannSearcher("alice_in_wonderland_index")
|
||||||
|
|
||||||
|
# Search with filters
|
||||||
|
results = searcher.search(
|
||||||
|
query="What happens to Alice?",
|
||||||
|
top_k=10,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": 5}, # Only chapters 1-5
|
||||||
|
"spoiler_level": {"!=": "high"} # No high spoilers
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Filter Syntax
|
||||||
|
|
||||||
|
### Basic Structure
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"field_name": {"operator": value},
|
||||||
|
"another_field": {"operator": value}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
#### Comparison Operators
|
||||||
|
- `"=="`: Equal to
|
||||||
|
- `"!="`: Not equal to
|
||||||
|
- `"<"`: Less than
|
||||||
|
- `"<="`: Less than or equal
|
||||||
|
- `">"`: Greater than
|
||||||
|
- `">="`: Greater than or equal
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"chapter": {"==": 1}} # Exactly chapter 1
|
||||||
|
{"page": {">": 100}} # Pages after 100
|
||||||
|
{"rating": {">=": 4.0}} # Rating 4.0 or higher
|
||||||
|
{"word_count": {"<": 500}} # Short passages
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Membership Operators
|
||||||
|
- `"in"`: Value is in list
|
||||||
|
- `"not_in"`: Value is not in list
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"character": {"in": ["Alice", "Bob"]}} # Alice OR Bob
|
||||||
|
{"genre": {"not_in": ["horror", "thriller"]}} # Exclude genres
|
||||||
|
{"tags": {"in": ["fiction", "adventure"]}} # Any of these tags
|
||||||
|
```
|
||||||
|
|
||||||
|
#### String Operators
|
||||||
|
- `"contains"`: String contains substring
|
||||||
|
- `"starts_with"`: String starts with prefix
|
||||||
|
- `"ends_with"`: String ends with suffix
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"title": {"contains": "alice"}} # Title contains "alice"
|
||||||
|
{"filename": {"ends_with": ".py"}} # Python files
|
||||||
|
{"author": {"starts_with": "Dr."}} # Authors with "Dr." prefix
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Boolean Operators
|
||||||
|
- `"is_true"`: Field is truthy
|
||||||
|
- `"is_false"`: Field is falsy
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Examples
|
||||||
|
{"is_published": {"is_true": True}} # Published content
|
||||||
|
{"is_draft": {"is_false": False}} # Not drafts
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Operators on Same Field
|
||||||
|
|
||||||
|
You can apply multiple operators to the same field (AND logic):
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"word_count": {
|
||||||
|
">=": 100, # At least 100 words
|
||||||
|
"<=": 500 # At most 500 words
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compound Filters
|
||||||
|
|
||||||
|
Multiple fields are combined with AND logic:
|
||||||
|
|
||||||
|
```python
|
||||||
|
metadata_filters = {
|
||||||
|
"chapter": {"<=": 10}, # Up to chapter 10
|
||||||
|
"character": {"==": "Alice"}, # About Alice
|
||||||
|
"spoiler_level": {"!=": "high"} # No major spoilers
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Use Case Examples
|
||||||
|
|
||||||
|
### 1. Spoiler-Free Book Search
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Reader has only read up to chapter 5
|
||||||
|
def search_spoiler_free(query, max_chapter):
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
metadata_filters={
|
||||||
|
"chapter": {"<=": max_chapter},
|
||||||
|
"spoiler_level": {"in": ["none", "low"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results = search_spoiler_free("What happens to Alice?", max_chapter=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Document Management by Date
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Find recent documents
|
||||||
|
recent_docs = searcher.search(
|
||||||
|
query="project updates",
|
||||||
|
metadata_filters={
|
||||||
|
"date": {">=": "2024-01-01"},
|
||||||
|
"document_type": {"==": "report"}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Code Search by File Type
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search only Python files
|
||||||
|
python_code = searcher.search(
|
||||||
|
query="authentication function",
|
||||||
|
metadata_filters={
|
||||||
|
"file_extension": {"==": ".py"},
|
||||||
|
"lines_of_code": {"<": 100}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Content Filtering by Audience
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Age-appropriate content
|
||||||
|
family_content = searcher.search(
|
||||||
|
query="adventure stories",
|
||||||
|
metadata_filters={
|
||||||
|
"age_rating": {"in": ["G", "PG"]},
|
||||||
|
"content_warnings": {"not_in": ["violence", "adult_themes"]}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Multi-Book Series Management
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Search across first 3 books only
|
||||||
|
early_series = searcher.search(
|
||||||
|
query="character development",
|
||||||
|
metadata_filters={
|
||||||
|
"series": {"==": "Harry Potter"},
|
||||||
|
"book_number": {"<=": 3}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running the Example
|
||||||
|
|
||||||
|
You can see metadata filtering in action with our spoiler-free book RAG example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Don't forget to set up the environment
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
|
||||||
|
export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Run the spoiler-free book RAG example
|
||||||
|
uv run examples/spoiler_free_book_rag.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This example demonstrates:
|
||||||
|
- Building an index with metadata (chapter numbers, characters, themes, locations)
|
||||||
|
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
|
||||||
|
- Different scenarios for readers at various points in the book
|
||||||
|
|
||||||
|
The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.
|
||||||
|
|
||||||
|
## Advanced Patterns
|
||||||
|
|
||||||
|
### Custom Chunking with metadata
|
||||||
|
|
||||||
|
```python
|
||||||
|
def chunk_book_with_metadata(book_text, book_info):
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
for chapter_num, chapter_text in parse_chapters(book_text):
|
||||||
|
# Extract entities, themes, etc.
|
||||||
|
characters = extract_characters(chapter_text)
|
||||||
|
themes = classify_themes(chapter_text)
|
||||||
|
spoiler_level = assess_spoiler_level(chapter_text, chapter_num)
|
||||||
|
|
||||||
|
# Create chunks with rich metadata
|
||||||
|
for paragraph in split_paragraphs(chapter_text):
|
||||||
|
chunks.append({
|
||||||
|
"text": paragraph,
|
||||||
|
"metadata": {
|
||||||
|
"book_title": book_info["title"],
|
||||||
|
"chapter": chapter_num,
|
||||||
|
"characters": characters,
|
||||||
|
"themes": themes,
|
||||||
|
"spoiler_level": spoiler_level,
|
||||||
|
"word_count": len(paragraph.split()),
|
||||||
|
"reading_level": calculate_reading_level(paragraph)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
### Efficient Filtering Strategies
|
||||||
|
|
||||||
|
1. **Post-search filtering**: Applies filters after vector search, which should be efficient for typical result sets (10-100 results).
|
||||||
|
|
||||||
|
2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
|
||||||
|
|
||||||
|
### Best Practices
|
||||||
|
|
||||||
|
1. **Consistent metadata schema**: Use consistent field names and value types across your documents.
|
||||||
|
|
||||||
|
2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.
|
||||||
|
|
||||||
|
3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).
|
||||||
|
|
||||||
|
4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.
|
||||||
|
|
||||||
|
### Adding Metadata to Existing Indices
|
||||||
|
|
||||||
|
To add metadata filtering to existing indices, you'll need to rebuild them with metadata:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Read existing passages and add metadata
|
||||||
|
def add_metadata_to_existing_chunks(chunks):
|
||||||
|
for chunk in chunks:
|
||||||
|
# Extract or assign metadata based on content
|
||||||
|
chunk["metadata"] = extract_metadata(chunk["text"])
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
# Rebuild index with metadata
|
||||||
|
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
|
||||||
|
builder = LeannBuilder("hnsw")
|
||||||
|
for chunk in enhanced_chunks:
|
||||||
|
builder.add_text(chunk["text"], chunk["metadata"])
|
||||||
|
builder.build_index("enhanced_index")
|
||||||
|
```
|
||||||
75
docs/normalized_embeddings.md
Normal file
75
docs/normalized_embeddings.md
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Normalized Embeddings Support in LEANN
|
||||||
|
|
||||||
|
LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.
|
||||||
|
|
||||||
|
## What are Normalized Embeddings?
|
||||||
|
|
||||||
|
Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
|
||||||
|
|
||||||
|
## Automatic Detection
|
||||||
|
|
||||||
|
When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:
|
||||||
|
|
||||||
|
1. **Automatically set `distance_metric="cosine"`** if not specified
|
||||||
|
2. **Show a warning** if you manually specify a different distance metric
|
||||||
|
3. **Provide optimal search performance** with the correct metric
|
||||||
|
|
||||||
|
## Supported Normalized Embedding Models
|
||||||
|
|
||||||
|
### OpenAI
|
||||||
|
All OpenAI text embedding models are normalized:
|
||||||
|
- `text-embedding-ada-002`
|
||||||
|
- `text-embedding-3-small`
|
||||||
|
- `text-embedding-3-large`
|
||||||
|
|
||||||
|
### Voyage AI
|
||||||
|
All Voyage AI embedding models are normalized:
|
||||||
|
- `voyage-2`
|
||||||
|
- `voyage-3`
|
||||||
|
- `voyage-large-2`
|
||||||
|
- `voyage-multilingual-2`
|
||||||
|
- `voyage-code-2`
|
||||||
|
|
||||||
|
### Cohere
|
||||||
|
All Cohere embedding models are normalized:
|
||||||
|
- `embed-english-v3.0`
|
||||||
|
- `embed-multilingual-v3.0`
|
||||||
|
- `embed-english-light-v3.0`
|
||||||
|
- `embed-multilingual-light-v3.0`
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
# Automatic detection - will use cosine distance
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai"
|
||||||
|
)
|
||||||
|
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
|
||||||
|
# Automatically setting distance_metric='cosine'
|
||||||
|
|
||||||
|
# Manual override (not recommended)
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model="text-embedding-3-small",
|
||||||
|
embedding_mode="openai",
|
||||||
|
distance_metric="mips" # Will show warning
|
||||||
|
)
|
||||||
|
# Warning: Using 'mips' distance metric with normalized embeddings...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Non-Normalized Embeddings
|
||||||
|
|
||||||
|
Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.
|
||||||
|
|
||||||
|
## Why This Matters
|
||||||
|
|
||||||
|
Using the wrong distance metric with normalized embeddings can lead to:
|
||||||
|
- **Poor search quality** due to HNSW's early termination with narrow score ranges
|
||||||
|
- **Incorrect ranking** of search results
|
||||||
|
- **Suboptimal performance** compared to using the correct metric
|
||||||
|
|
||||||
|
For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.
|
||||||
21
docs/roadmap.md
Normal file
21
docs/roadmap.md
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# 📈 Roadmap
|
||||||
|
|
||||||
|
## 🎯 Q2 2025
|
||||||
|
|
||||||
|
- [X] HNSW backend integration
|
||||||
|
- [X] DiskANN backend with MIPS/L2/Cosine support
|
||||||
|
- [X] Real-time embedding pipeline
|
||||||
|
- [X] Memory-efficient graph pruning
|
||||||
|
|
||||||
|
## 🚀 Q3 2025
|
||||||
|
|
||||||
|
- [ ] Advanced caching strategies
|
||||||
|
- [ ] Add contextual-retrieval https://www.anthropic.com/news/contextual-retrieval
|
||||||
|
- [ ] Add sleep-time-compute and summarize agent! to summarilze the file on computer!
|
||||||
|
- [ ] Add OpenAI recompute API
|
||||||
|
|
||||||
|
## 🌟 Q4 2025
|
||||||
|
|
||||||
|
- [ ] Integration with LangChain/LlamaIndex
|
||||||
|
- [ ] Visual similarity search
|
||||||
|
- [ ] Query rewrtiting, rerank and expansion
|
||||||
0
examples/__init__.py
Normal file
0
examples/__init__.py
Normal file
@@ -1,16 +1,23 @@
|
|||||||
"""
|
"""
|
||||||
Simple demo showing basic leann usage
|
Simple demo showing basic leann usage
|
||||||
Run: uv run python examples/simple_demo.py
|
Run: uv run python examples/basic_demo.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from leann import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
from leann import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Simple demo of Leann with selectable embedding models.")
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2",
|
description="Simple demo of Leann with selectable embedding models."
|
||||||
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.")
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding_model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
print(f"=== Leann Simple Demo with {args.embedding_model} ===")
|
||||||
@@ -74,7 +81,7 @@ def main():
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
print("Demo completed! Try running:")
|
print("Demo completed! Try running:")
|
||||||
print(" uv run python examples/document_search.py")
|
print(" uv run python apps/document_rag.py")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Document search demo with recompute mode
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
import shutil
|
|
||||||
import time
|
|
||||||
|
|
||||||
# Import backend packages to trigger plugin registration
|
|
||||||
try:
|
|
||||||
import leann_backend_diskann
|
|
||||||
import leann_backend_hnsw
|
|
||||||
print("INFO: Backend packages imported successfully.")
|
|
||||||
except ImportError as e:
|
|
||||||
print(f"WARNING: Could not import backend packages. Error: {e}")
|
|
||||||
|
|
||||||
# Import upper-level API from leann-core
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
|
||||||
|
|
||||||
def load_sample_documents():
|
|
||||||
"""Create sample documents for demonstration"""
|
|
||||||
docs = [
|
|
||||||
{"title": "Intro to Python", "content": "Python is a high-level, interpreted language known for simplicity."},
|
|
||||||
{"title": "ML Basics", "content": "Machine learning builds systems that learn from data."},
|
|
||||||
{"title": "Data Structures", "content": "Data structures like arrays, lists, and graphs organize data."},
|
|
||||||
]
|
|
||||||
return docs
|
|
||||||
|
|
||||||
def main():
|
|
||||||
print("==========================================================")
|
|
||||||
print("=== Leann Document Search Demo (DiskANN + Recompute) ===")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_indices")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "documents.diskann")
|
|
||||||
BACKEND_TO_TEST = "diskann"
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
print(f"--- Cleaning up old index directory: {INDEX_DIR} ---")
|
|
||||||
shutil.rmtree(INDEX_DIR)
|
|
||||||
|
|
||||||
# --- 1. Build index ---
|
|
||||||
print(f"\n[PHASE 1] Building index using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=BACKEND_TO_TEST,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64
|
|
||||||
)
|
|
||||||
|
|
||||||
documents = load_sample_documents()
|
|
||||||
print(f"Loaded {len(documents)} sample documents.")
|
|
||||||
for doc in documents:
|
|
||||||
builder.add_text(doc["content"], metadata={"title": doc["title"]})
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nIndex built!")
|
|
||||||
|
|
||||||
# --- 2. Basic search demo ---
|
|
||||||
print(f"\n[PHASE 2] Basic search using '{BACKEND_TO_TEST}' backend...")
|
|
||||||
searcher = LeannSearcher(index_path=INDEX_PATH)
|
|
||||||
|
|
||||||
query = "What is machine learning?"
|
|
||||||
print(f"\nQuery: '{query}'")
|
|
||||||
|
|
||||||
print("\n--- Basic search mode (PQ computation) ---")
|
|
||||||
start_time = time.time()
|
|
||||||
results = searcher.search(query, top_k=2)
|
|
||||||
basic_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(">>> Basic search results <<<")
|
|
||||||
for i, res in enumerate(results, 1):
|
|
||||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
|
||||||
|
|
||||||
# --- 3. Recompute search demo ---
|
|
||||||
print(f"\n[PHASE 3] Recompute search using embedding server...")
|
|
||||||
|
|
||||||
print("\n--- Recompute search mode (get real embeddings via network) ---")
|
|
||||||
|
|
||||||
# Configure recompute parameters
|
|
||||||
recompute_params = {
|
|
||||||
"recompute_beighbor_embeddings": True, # Enable network recomputation
|
|
||||||
"USE_DEFERRED_FETCH": False, # Don't use deferred fetch
|
|
||||||
"skip_search_reorder": True, # Skip search reordering
|
|
||||||
"dedup_node_dis": True, # Enable node distance deduplication
|
|
||||||
"prune_ratio": 0.1, # Pruning ratio 10%
|
|
||||||
"batch_recompute": False, # Don't use batch recomputation
|
|
||||||
"global_pruning": False, # Don't use global pruning
|
|
||||||
"zmq_port": 5555, # ZMQ port
|
|
||||||
"embedding_model": "sentence-transformers/all-mpnet-base-v2"
|
|
||||||
}
|
|
||||||
|
|
||||||
print("Recompute parameter configuration:")
|
|
||||||
for key, value in recompute_params.items():
|
|
||||||
print(f" {key}: {value}")
|
|
||||||
|
|
||||||
print(f"\n🔄 Executing Recompute search...")
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
recompute_results = searcher.search(query, top_k=2, **recompute_params)
|
|
||||||
recompute_time = time.time() - start_time
|
|
||||||
|
|
||||||
print(f"⏱️ Recompute search time: {recompute_time:.3f} seconds")
|
|
||||||
print(">>> Recompute search results <<<")
|
|
||||||
for i, res in enumerate(recompute_results, 1):
|
|
||||||
print(f" {i}. ID: {res.id}, Score: {res.score:.4f}, Text: '{res.text}', Metadata: {res.metadata}")
|
|
||||||
|
|
||||||
# Compare results
|
|
||||||
print(f"\n--- Result comparison ---")
|
|
||||||
print(f"Basic search time: {basic_time:.3f} seconds")
|
|
||||||
print(f"Recompute time: {recompute_time:.3f} seconds")
|
|
||||||
|
|
||||||
print("\nBasic search vs Recompute results:")
|
|
||||||
for i in range(min(len(results), len(recompute_results))):
|
|
||||||
basic_score = results[i].score
|
|
||||||
recompute_score = recompute_results[i].score
|
|
||||||
score_diff = abs(basic_score - recompute_score)
|
|
||||||
print(f" Position {i+1}: PQ={basic_score:.4f}, Recompute={recompute_score:.4f}, Difference={score_diff:.4f}")
|
|
||||||
|
|
||||||
if recompute_time > basic_time:
|
|
||||||
print(f"✅ Recompute mode working correctly (more accurate but slower)")
|
|
||||||
else:
|
|
||||||
print(f"ℹ️ Recompute time is unusually fast, network recomputation may not be enabled")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Recompute search failed: {e}")
|
|
||||||
print("This usually indicates an embedding server connection issue")
|
|
||||||
|
|
||||||
# --- 4. Chat demo ---
|
|
||||||
print(f"\n[PHASE 4] Starting chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
chat_response = chat.ask(query)
|
|
||||||
print(f"You: {query}")
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
print("\n==========================================================")
|
|
||||||
print("✅ Demo finished successfully!")
|
|
||||||
print("==========================================================")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
429
examples/dynamic_update_no_recompute.py
Normal file
429
examples/dynamic_update_no_recompute.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
"""Dynamic HNSW update demo without compact storage.
|
||||||
|
|
||||||
|
This script reproduces the minimal scenario we used while debugging on-the-fly
|
||||||
|
recompute:
|
||||||
|
|
||||||
|
1. Build a non-compact HNSW index from the first few paragraphs of a text file.
|
||||||
|
2. Print the top results with `recompute_embeddings=True`.
|
||||||
|
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
|
||||||
|
4. Run the same query again to show the newly inserted passages.
|
||||||
|
|
||||||
|
Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
|
||||||
|
ZMQ activity)::
|
||||||
|
|
||||||
|
LEANN_HNSW_LOG_PATH=embedding_fetch.log \
|
||||||
|
uv run -m examples.dynamic_update_no_recompute \
|
||||||
|
--index-path .leann/examples/leann-demo.leann
|
||||||
|
|
||||||
|
By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
|
||||||
|
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
|
||||||
|
It issues the query "What's LEANN?" before and after the update to show how the
|
||||||
|
new passages become immediately searchable. The script uses the
|
||||||
|
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
|
||||||
|
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
|
||||||
|
freshly added passages are embedded locally just like the initial build.
|
||||||
|
|
||||||
|
To make storage comparisons easy, the script can also build a matching
|
||||||
|
``is_recompute=False`` baseline (enabled by default) and report the index size
|
||||||
|
delta after the update. Disable the baseline run with
|
||||||
|
``--skip-compare-no-recompute`` if you only need the recompute flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
DEFAULT_QUERY = "What's LEANN?"
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
REPO_ROOT / "data" / "PrideandPrejudice.txt",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path]) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
return [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
|
||||||
|
|
||||||
|
def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
|
||||||
|
searcher = LeannSearcher(str(index_path))
|
||||||
|
try:
|
||||||
|
return searcher.search(
|
||||||
|
query=query,
|
||||||
|
top_k=top_k,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(title: str, results: Iterable) -> None:
|
||||||
|
print(f"\n=== {title} ===")
|
||||||
|
res_list = list(results)
|
||||||
|
print(f"results count: {len(res_list)}")
|
||||||
|
print("passages:")
|
||||||
|
if not res_list:
|
||||||
|
print(" (no passages returned)")
|
||||||
|
for res in res_list:
|
||||||
|
snippet = res.text.replace("\n", " ")[:120]
|
||||||
|
print(f" - {res.id}: {snippet}... (score={res.score:.4f})")
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def update_index(
|
||||||
|
index_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
) -> None:
|
||||||
|
updater = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
for offset, passage in enumerate(paragraphs, start=start_id):
|
||||||
|
updater.add_text(passage, metadata={"id": str(offset)})
|
||||||
|
updater.update_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
"""Remove leftover index artifacts for a clean rebuild."""
|
||||||
|
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def index_file_size(index_path: Path) -> int:
|
||||||
|
"""Return the size of the primary .index file for the given index path."""
|
||||||
|
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
return index_file.stat().st_size if index_file.exists() else 0
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
if not meta_path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(meta_path.read_text())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run_workflow(
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
index_path: Path,
|
||||||
|
initial_paragraphs: list[str],
|
||||||
|
update_paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
is_recompute: bool,
|
||||||
|
query: str,
|
||||||
|
top_k: int,
|
||||||
|
skip_search: bool,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
prefix = f"[{label}] " if label else ""
|
||||||
|
|
||||||
|
ensure_index_dir(index_path)
|
||||||
|
cleanup_index_files(index_path)
|
||||||
|
|
||||||
|
print(f"{prefix}Building initial index...")
|
||||||
|
build_initial_index(
|
||||||
|
index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
model_name,
|
||||||
|
embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
initial_size = index_file_size(index_path)
|
||||||
|
if not skip_search:
|
||||||
|
before_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
before_results = None
|
||||||
|
|
||||||
|
print(f"\n{prefix}Updating index with additional passages...")
|
||||||
|
update_index(
|
||||||
|
index_path,
|
||||||
|
start_id=len(initial_paragraphs),
|
||||||
|
paragraphs=update_paragraphs,
|
||||||
|
model_name=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not skip_search:
|
||||||
|
after_results = run_search(
|
||||||
|
index_path,
|
||||||
|
query,
|
||||||
|
top_k,
|
||||||
|
recompute_embeddings=is_recompute,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
after_results = None
|
||||||
|
updated_size = index_file_size(index_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"initial_size": initial_size,
|
||||||
|
"updated_size": updated_size,
|
||||||
|
"delta": updated_size - initial_size,
|
||||||
|
"before_results": before_results if not skip_search else None,
|
||||||
|
"after_results": after_results if not skip_search else None,
|
||||||
|
"metadata": load_metadata_snapshot(index_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="+",
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
help="Initial document files (PDF/TXT) used to build the base index",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/examples/leann-demo.leann"),
|
||||||
|
help="Destination index path (default: .leann/examples/leann-demo.leann)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-count",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of chunks to use from the initial documents (default: 8)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
type=Path,
|
||||||
|
nargs="*",
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
help="Additional documents to add during update (PDF/TXT)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-count",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of chunks to append from update documents (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-text",
|
||||||
|
type=str,
|
||||||
|
default=(
|
||||||
|
"LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
|
||||||
|
"recompute-aware HNSW graphs, allowing embeddings to be regenerated "
|
||||||
|
"on demand to keep disk usage minimal."
|
||||||
|
),
|
||||||
|
help="Fallback text to append if --update-files is omitted",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of results to show for each search (default: 4)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_QUERY,
|
||||||
|
help="Query to run before/after the update",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
help="Embedding model name",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Also run a baseline with is_recompute=False and report its index growth.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-compare-no-recompute",
|
||||||
|
dest="compare_no_recompute",
|
||||||
|
action="store_false",
|
||||||
|
help="Skip building the no-recompute baseline.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip-search",
|
||||||
|
dest="skip_search",
|
||||||
|
action="store_true",
|
||||||
|
help="Skip the search step.",
|
||||||
|
)
|
||||||
|
parser.set_defaults(compare_no_recompute=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
initial_chunks = load_chunks_from_files(list(args.initial_files))
|
||||||
|
if not initial_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the initial files.")
|
||||||
|
|
||||||
|
initial = initial_chunks[: args.initial_count]
|
||||||
|
if not initial:
|
||||||
|
raise ValueError("Initial chunk set is empty after applying --initial-count.")
|
||||||
|
|
||||||
|
if args.update_files:
|
||||||
|
update_chunks = load_chunks_from_files(list(args.update_files))
|
||||||
|
if not update_chunks:
|
||||||
|
raise ValueError("No text chunks extracted from the update files.")
|
||||||
|
to_add = update_chunks[: args.update_count]
|
||||||
|
else:
|
||||||
|
if not args.update_text:
|
||||||
|
raise ValueError("Provide --update-files or --update-text for the update step.")
|
||||||
|
to_add = [args.update_text]
|
||||||
|
if not to_add:
|
||||||
|
raise ValueError("Update chunk set is empty after applying --update-count.")
|
||||||
|
|
||||||
|
recompute_stats = run_workflow(
|
||||||
|
label="recompute",
|
||||||
|
index_path=args.index_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=True,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
skip_search=args.skip_search,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not args.skip_search:
|
||||||
|
print_results("initial search", recompute_stats["before_results"])
|
||||||
|
if not args.skip_search:
|
||||||
|
print_results("after update", recompute_stats["after_results"])
|
||||||
|
print(
|
||||||
|
f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {recompute_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recompute_stats["metadata"]:
|
||||||
|
meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
if args.compare_no_recompute:
|
||||||
|
baseline_path = (
|
||||||
|
args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
|
||||||
|
)
|
||||||
|
baseline_stats = run_workflow(
|
||||||
|
label="no-recompute",
|
||||||
|
index_path=baseline_path,
|
||||||
|
initial_paragraphs=initial,
|
||||||
|
update_paragraphs=to_add,
|
||||||
|
model_name=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
is_recompute=False,
|
||||||
|
query=args.query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
skip_search=args.skip_search,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
|
||||||
|
f" (Δ {baseline_stats['delta']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
after_texts = (
|
||||||
|
[res.text for res in recompute_stats["after_results"]] if not args.skip_search else None
|
||||||
|
)
|
||||||
|
baseline_after_texts = (
|
||||||
|
[res.text for res in baseline_stats["after_results"]] if not args.skip_search else None
|
||||||
|
)
|
||||||
|
if after_texts == baseline_after_texts:
|
||||||
|
print(
|
||||||
|
"[no-recompute] Search results match recompute baseline; see above for the shared output."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[no-recompute] WARNING: search results differ from recompute baseline.")
|
||||||
|
|
||||||
|
if baseline_stats["metadata"]:
|
||||||
|
meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
|
||||||
|
print("[no-recompute] metadata snapshot:")
|
||||||
|
print(json.dumps(meta_view, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
import os
|
|
||||||
import email
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
def find_all_messages_directories(root: str = None) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Recursively find all 'Messages' directories under the given root.
|
|
||||||
Returns a list of Path objects.
|
|
||||||
"""
|
|
||||||
if root is None:
|
|
||||||
# Auto-detect user's mail path
|
|
||||||
home_dir = os.path.expanduser("~")
|
|
||||||
root = os.path.join(home_dir, "Library", "Mail")
|
|
||||||
|
|
||||||
messages_dirs = []
|
|
||||||
for dirpath, dirnames, filenames in os.walk(root):
|
|
||||||
if os.path.basename(dirpath) == "Messages":
|
|
||||||
messages_dirs.append(Path(dirpath))
|
|
||||||
return messages_dirs
|
|
||||||
|
|
||||||
class EmlxReader(BaseReader):
|
|
||||||
"""
|
|
||||||
Apple Mail .emlx file reader with embedded metadata.
|
|
||||||
|
|
||||||
Reads individual .emlx files from Apple Mail's storage format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, include_html: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
Initialize.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
include_html: Whether to include HTML content in the email body (default: False)
|
|
||||||
"""
|
|
||||||
self.include_html = include_html
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
|
|
||||||
"""
|
|
||||||
Load data from the input directory containing .emlx files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing .emlx files
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum amount of messages to read.
|
|
||||||
"""
|
|
||||||
docs: List[Document] = []
|
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
# Walk through the directory recursively
|
|
||||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
|
||||||
# Skip hidden directories
|
|
||||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
|
||||||
|
|
||||||
for filename in filenames:
|
|
||||||
if count >= max_count:
|
|
||||||
break
|
|
||||||
|
|
||||||
if filename.endswith(".emlx"):
|
|
||||||
filepath = os.path.join(dirpath, filename)
|
|
||||||
try:
|
|
||||||
# Read the .emlx file
|
|
||||||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# .emlx files have a length prefix followed by the email content
|
|
||||||
# The first line contains the length, followed by the email
|
|
||||||
lines = content.split('\n', 1)
|
|
||||||
if len(lines) >= 2:
|
|
||||||
email_content = lines[1]
|
|
||||||
|
|
||||||
# Parse the email using Python's email module
|
|
||||||
try:
|
|
||||||
msg = email.message_from_string(email_content)
|
|
||||||
|
|
||||||
# Extract email metadata
|
|
||||||
subject = msg.get('Subject', 'No Subject')
|
|
||||||
from_addr = msg.get('From', 'Unknown')
|
|
||||||
to_addr = msg.get('To', 'Unknown')
|
|
||||||
date = msg.get('Date', 'Unknown')
|
|
||||||
|
|
||||||
# Extract email body
|
|
||||||
body = ""
|
|
||||||
if msg.is_multipart():
|
|
||||||
for part in msg.walk():
|
|
||||||
if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
|
|
||||||
if part.get_content_type() == "text/html" and not self.include_html:
|
|
||||||
continue
|
|
||||||
body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
|
|
||||||
# break
|
|
||||||
else:
|
|
||||||
body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')
|
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
|
||||||
doc_content = f"""
|
|
||||||
[File]: {filename}
|
|
||||||
[From]: {from_addr}
|
|
||||||
[To]: {to_addr}
|
|
||||||
[Subject]: {subject}
|
|
||||||
[Date]: {date}
|
|
||||||
[EMAIL BODY Start]:
|
|
||||||
{body}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# No separate metadata - everything is in the text
|
|
||||||
doc = Document(text=doc_content, metadata={})
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error parsing email from {filepath}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading file {filepath}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Loaded {len(docs)} email documents")
|
|
||||||
return docs
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import argparse
|
|
||||||
try:
|
|
||||||
import dotenv
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
# python-dotenv is not installed; skip loading environment variables
|
|
||||||
dotenv = None
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
# dotenv.load_dotenv() # handled above if python-dotenv is available
|
|
||||||
|
|
||||||
# Default Chrome profile path
|
|
||||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
|
||||||
"""
|
|
||||||
Create LEANN index from multiple Chrome profile data sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of history entries to process per profile
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
|
||||||
|
|
||||||
# Load documents using ChromeHistoryReader from history_data
|
|
||||||
from history_data.history import ChromeHistoryReader
|
|
||||||
reader = ChromeHistoryReader()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
# Process each Chrome profile directory
|
|
||||||
for i, profile_dir in enumerate(profile_dirs):
|
|
||||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
documents = reader.load_data(
|
|
||||||
chrome_profile_path=str(profile_dir),
|
|
||||||
max_count=max_count
|
|
||||||
)
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
|
|
||||||
# Check if we've reached the max count
|
|
||||||
if max_count > 0 and total_processed >= max_count:
|
|
||||||
print(f"Reached max count of {max_count} documents")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {profile_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {profile_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
# highlight info that you need to close all chrome browser before running this script and high light the instruction!!
|
|
||||||
print("\033[91mYou need to close or quit all chrome browser before running this script\033[0m")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
text = node.get_content()
|
|
||||||
# text = '[Title] ' + doc.metadata["title"] + '\n' + text
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1 # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
|
|
||||||
"""
|
|
||||||
Create LEANN index from Chrome history data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
profile_path: Path to the Chrome profile directory (optional, uses default if None)
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of history entries to process
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from Chrome history data...")
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Load documents using ChromeHistoryReader from history_data
|
|
||||||
from history_data.history import ChromeHistoryReader
|
|
||||||
reader = ChromeHistoryReader()
|
|
||||||
|
|
||||||
documents = reader.load_data(
|
|
||||||
chrome_profile_path=profile_path,
|
|
||||||
max_count=max_count
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print("No documents loaded. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} history documents")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1 # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
|
||||||
"""
|
|
||||||
Query the LEANN index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: The query string
|
|
||||||
"""
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=index_path)
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=10,
|
|
||||||
recompute_beighbor_embeddings=True,
|
|
||||||
complexity=32,
|
|
||||||
beam_width=1,
|
|
||||||
llm_config={
|
|
||||||
"type": "openai",
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
llm_kwargs={
|
|
||||||
"temperature": 0.0,
|
|
||||||
"max_tokens": 1000
|
|
||||||
}
|
|
||||||
)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
|
||||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
|
||||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}), usually you dont need to change this')
|
|
||||||
parser.add_argument('--index-dir', type=str, default="./all_google_new",
|
|
||||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
|
||||||
parser.add_argument('--max-entries', type=int, default=1000,
|
|
||||||
help='Maximum number of history entries to process (default: 1000)')
|
|
||||||
parser.add_argument('--query', type=str, default=None,
|
|
||||||
help='Single query to run (default: runs example queries)')
|
|
||||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
|
||||||
help='Automatically find all Chrome profiles (default: True)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
|
||||||
|
|
||||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
|
||||||
print(f"Max entries: {args.max_entries}")
|
|
||||||
|
|
||||||
# Find Chrome profile directories
|
|
||||||
from history_data.history import ChromeHistoryReader
|
|
||||||
|
|
||||||
if args.auto_find_profiles:
|
|
||||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
|
||||||
if not profile_dirs:
|
|
||||||
print("No Chrome profiles found automatically. Exiting.")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# Use single specified profile
|
|
||||||
profile_path = Path(args.chrome_profile)
|
|
||||||
if not profile_path.exists():
|
|
||||||
print(f"Chrome profile not found: {profile_path}")
|
|
||||||
return
|
|
||||||
profile_dirs = [profile_path]
|
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
|
||||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
|
||||||
|
|
||||||
if index_path:
|
|
||||||
if args.query:
|
|
||||||
# Run single query
|
|
||||||
await query_leann_index(index_path, args.query)
|
|
||||||
else:
|
|
||||||
# Example queries
|
|
||||||
queries = [
|
|
||||||
"What websites did I visit about machine learning?",
|
|
||||||
"Find my search history about programming"
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "="*60)
|
|
||||||
await query_leann_index(index_path, query)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
35
examples/grep_search_example.py
Normal file
35
examples/grep_search_example.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Grep Search Example
|
||||||
|
|
||||||
|
Shows how to use grep-based text search instead of semantic search.
|
||||||
|
Useful when you need exact text matches rather than meaning-based results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Load your index
|
||||||
|
searcher = LeannSearcher("my-documents.leann")
|
||||||
|
|
||||||
|
# Regular semantic search
|
||||||
|
print("=== Semantic Search ===")
|
||||||
|
results = searcher.search("machine learning algorithms", top_k=3)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score:.3f}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Grep-based search for exact text matches
|
||||||
|
print("=== Grep Search ===")
|
||||||
|
results = searcher.search("def train_model", top_k=3, use_grep=True)
|
||||||
|
for result in results:
|
||||||
|
print(f"Score: {result.score}")
|
||||||
|
print(f"Text: {result.text[:80]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# Find specific error messages
|
||||||
|
error_results = searcher.search("FileNotFoundError", use_grep=True)
|
||||||
|
print(f"Found {len(error_results)} files mentioning FileNotFoundError")
|
||||||
|
|
||||||
|
# Search for function definitions
|
||||||
|
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
|
||||||
|
print(f"Found {len(func_results)} class definitions")
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
|
|
||||||
# Add the project root to Python path so we can import from examples
|
|
||||||
project_root = Path(__file__).parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
# Auto-detect user's mail path
|
|
||||||
def get_mail_path():
|
|
||||||
"""Get the mail path for the current user"""
|
|
||||||
home_dir = os.path.expanduser("~")
|
|
||||||
return os.path.join(home_dir, "Library", "Mail")
|
|
||||||
|
|
||||||
# Default mail path for macOS
|
|
||||||
# DEFAULT_MAIL_PATH = "/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data"
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
|
||||||
"""
|
|
||||||
Create LEANN index from multiple mail data sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages_dirs: List of Path objects pointing to Messages directories
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of emails to process per directory
|
|
||||||
include_html: Whether to include HTML content in email processing
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from multiple mail data sources...")
|
|
||||||
|
|
||||||
# Load documents using EmlxReader from LEANN_email_reader
|
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
|
||||||
reader = EmlxReader(include_html=include_html)
|
|
||||||
# from email_data.email import EmlxMboxReader
|
|
||||||
# from pathlib import Path
|
|
||||||
# reader = EmlxMboxReader()
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
# Process each Messages directory
|
|
||||||
for i, messages_dir in enumerate(messages_dirs):
|
|
||||||
print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
documents = reader.load_data(messages_dir)
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} email documents from {messages_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
|
|
||||||
# Check if we've reached the max count
|
|
||||||
if max_count > 0 and total_processed >= max_count:
|
|
||||||
print(f"Reached max count of {max_count} documents")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {messages_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {messages_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories and starting to split them into chunks")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
text = node.get_content()
|
|
||||||
# text = '[subject] ' + doc.metadata["subject"] + '\n' + text
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1 # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
def create_leann_index(mail_path: str, index_path: str = "mail_index.leann", max_count: int = 1000, include_html: bool = False, embedding_model: str = "facebook/contriever"):
|
|
||||||
"""
|
|
||||||
Create LEANN index from mail data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mail_path: Path to the mail directory
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of emails to process
|
|
||||||
include_html: Whether to include HTML content in email processing
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from mail data...")
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Load documents using EmlxReader from LEANN_email_reader
|
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
|
||||||
reader = EmlxReader(include_html=include_html)
|
|
||||||
# from email_data.email import EmlxMboxReader
|
|
||||||
# from pathlib import Path
|
|
||||||
# reader = EmlxMboxReader()
|
|
||||||
documents = reader.load_data(Path(mail_path))
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print("No documents loaded. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} email documents")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1 # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} email chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
|
||||||
"""
|
|
||||||
Query the LEANN index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: The query string
|
|
||||||
"""
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=index_path,
|
|
||||||
llm_config={"type": "openai", "model": "gpt-4o"})
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
chat_response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=10,
|
|
||||||
recompute_beighbor_embeddings=True,
|
|
||||||
complexity=12,
|
|
||||||
beam_width=1,
|
|
||||||
|
|
||||||
)
|
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken: {end_time - start_time} seconds")
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
|
|
||||||
# Remove --mail-path argument and auto-detect all Messages directories
|
|
||||||
# Remove DEFAULT_MAIL_PATH
|
|
||||||
parser.add_argument('--index-dir', type=str, default="./mail_index_leann_debug",
|
|
||||||
help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
|
|
||||||
parser.add_argument('--max-emails', type=int, default=1000,
|
|
||||||
help='Maximum number of emails to process (-1 means all)')
|
|
||||||
parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
|
|
||||||
help='Single query to run (default: runs example queries)')
|
|
||||||
parser.add_argument('--include-html', action='store_true', default=False,
|
|
||||||
help='Include HTML content in email processing (default: False)')
|
|
||||||
parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
|
|
||||||
help='Embedding model to use (default: facebook/contriever)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
print(f"args: {args}")
|
|
||||||
|
|
||||||
# Automatically find all Messages directories under the current user's Mail directory
|
|
||||||
from examples.email_data.LEANN_email_reader import find_all_messages_directories
|
|
||||||
mail_path = get_mail_path()
|
|
||||||
print(f"Searching for email data in: {mail_path}")
|
|
||||||
messages_dirs = find_all_messages_directories(mail_path)
|
|
||||||
|
|
||||||
print('len(messages_dirs): ', len(messages_dirs))
|
|
||||||
|
|
||||||
|
|
||||||
if not messages_dirs:
|
|
||||||
print("No Messages directories found. Exiting.")
|
|
||||||
return
|
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
|
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
|
||||||
print(f"Found {len(messages_dirs)} Messages directories.")
|
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
|
||||||
index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)
|
|
||||||
|
|
||||||
if index_path:
|
|
||||||
if args.query:
|
|
||||||
# Run single query
|
|
||||||
await query_leann_index(index_path, args.query)
|
|
||||||
else:
|
|
||||||
# Example queries
|
|
||||||
queries = [
|
|
||||||
"Hows Berkeley Graduate Student Instructor",
|
|
||||||
"how's the icloud related advertisement saying",
|
|
||||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
|
||||||
]
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "="*60)
|
|
||||||
await query_leann_index(index_path, query)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
|
|
||||||
# Add the project root to Python path so we can import from examples
|
|
||||||
project_root = Path(__file__).parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from llama_index.core import VectorStoreIndex, StorageContext
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
# --- EMBEDDING MODEL ---
|
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# --- END EMBEDDING MODEL ---
|
|
||||||
|
|
||||||
# Import EmlxReader from the new module
|
|
||||||
from examples.email_data.LEANN_email_reader import EmlxReader
|
|
||||||
|
|
||||||
def create_and_save_index(mail_path: str, save_dir: str = "mail_index_embedded", max_count: int = 1000, include_html: bool = False):
|
|
||||||
print("Creating index from mail data with embedded metadata...")
|
|
||||||
documents = EmlxReader(include_html=include_html).load_data(mail_path, max_count=max_count)
|
|
||||||
if not documents:
|
|
||||||
print("No documents loaded. Exiting.")
|
|
||||||
return None
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
||||||
# Use facebook/contriever as the embedder
|
|
||||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
|
||||||
# set on device
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
embed_model._model.to("cuda")
|
|
||||||
# set mps
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
embed_model._model.to("mps")
|
|
||||||
else:
|
|
||||||
embed_model._model.to("cpu")
|
|
||||||
index = VectorStoreIndex.from_documents(
|
|
||||||
documents,
|
|
||||||
transformations=[text_splitter],
|
|
||||||
embed_model=embed_model
|
|
||||||
)
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
index.storage_context.persist(persist_dir=save_dir)
|
|
||||||
print(f"Index saved to {save_dir}")
|
|
||||||
return index
|
|
||||||
|
|
||||||
def load_index(save_dir: str = "mail_index_embedded"):
|
|
||||||
try:
|
|
||||||
storage_context = StorageContext.from_defaults(persist_dir=save_dir)
|
|
||||||
index = VectorStoreIndex.from_vector_store(
|
|
||||||
storage_context.vector_store,
|
|
||||||
storage_context=storage_context
|
|
||||||
)
|
|
||||||
print(f"Index loaded from {save_dir}")
|
|
||||||
return index
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading index: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def query_index(index, query: str):
|
|
||||||
if index is None:
|
|
||||||
print("No index available for querying.")
|
|
||||||
return
|
|
||||||
query_engine = index.as_query_engine()
|
|
||||||
response = query_engine.query(query)
|
|
||||||
print(f"Query: {query}")
|
|
||||||
print(f"Response: {response}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(description='LlamaIndex Mail Reader - Create and query email index')
|
|
||||||
parser.add_argument('--mail-path', type=str,
|
|
||||||
default="/Users/yichuan/Library/Mail/V10/0FCA0879-FD8C-4B7E-83BF-FDDA930791C5/[Gmail].mbox/All Mail.mbox/78BA5BE1-8819-4F9A-9613-EB63772F1DD0/Data/9/Messages",
|
|
||||||
help='Path to mail data directory')
|
|
||||||
parser.add_argument('--save-dir', type=str, default="mail_index_embedded",
|
|
||||||
help='Directory to store the index (default: mail_index_embedded)')
|
|
||||||
parser.add_argument('--max-emails', type=int, default=10000,
|
|
||||||
help='Maximum number of emails to process')
|
|
||||||
parser.add_argument('--include-html', action='store_true', default=False,
|
|
||||||
help='Include HTML content in email processing (default: False)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
mail_path = args.mail_path
|
|
||||||
save_dir = args.save_dir
|
|
||||||
|
|
||||||
if os.path.exists(save_dir) and os.path.exists(os.path.join(save_dir, "vector_store.json")):
|
|
||||||
print("Loading existing index...")
|
|
||||||
index = load_index(save_dir)
|
|
||||||
else:
|
|
||||||
print("Creating new index...")
|
|
||||||
index = create_and_save_index(mail_path, save_dir, max_count=args.max_emails, include_html=args.include_html)
|
|
||||||
if index:
|
|
||||||
queries = [
|
|
||||||
"Hows Berkeley Graduate Student Instructor",
|
|
||||||
"how's the icloud related advertisement saying",
|
|
||||||
"Whats the number of class recommend to take per semester for incoming EECS students"
|
|
||||||
]
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "="*50)
|
|
||||||
query_index(index, query)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannChat
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
async def main(args):
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Loading documents...")
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
args.data_dir,
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
print("Documents loaded.")
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print("--- Index directory not found, building new index ---")
|
|
||||||
|
|
||||||
print("\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
|
|
||||||
llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
|
||||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
|
||||||
llm_config = {"type": "openai", "model": "gpt-4o"}
|
|
||||||
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
|
||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
|
|
||||||
# query = (
|
|
||||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
|
||||||
# )
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run Leann Chat with various LLM backends."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--llm",
|
|
||||||
type=str,
|
|
||||||
default="hf",
|
|
||||||
choices=["simulated", "ollama", "hf", "openai"],
|
|
||||||
help="The LLM backend to use.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
type=str,
|
|
||||||
default="Qwen/Qwen3-0.6B",
|
|
||||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default="http://localhost:11434",
|
|
||||||
help="The host for the Ollama API.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default="./test_doc_files",
|
|
||||||
help="Directory where the Leann index will be stored.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--data-dir",
|
|
||||||
type=str,
|
|
||||||
default="examples/data",
|
|
||||||
help="Directory containing documents to index (PDF, TXT, MD files).",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
asyncio.run(main(args))
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
|
||||||
# Define the path for our new MLX-based index
|
# Define the path for our new MLX-based index
|
||||||
INDEX_PATH = "./mlx_diskann_index/leann"
|
INDEX_PATH = "./mlx_diskann_index/leann"
|
||||||
@@ -38,7 +39,5 @@ chat = LeannChat(index_path=INDEX_PATH)
|
|||||||
# add query
|
# add query
|
||||||
query = "MLX is an array framework for machine learning on Apple silicon."
|
query = "MLX is an array framework for machine learning on Apple silicon."
|
||||||
print(f"Query: {query}")
|
print(f"Query: {query}")
|
||||||
response = chat.ask(
|
response = chat.ask(query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1)
|
||||||
query, top_k=3, recompute_beighbor_embeddings=True, complexity=3, beam_width=1
|
|
||||||
)
|
|
||||||
print(f"Response: {response}")
|
print(f"Response: {response}")
|
||||||
@@ -1,319 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Multi-Vector Aggregator for Fat Embeddings
|
|
||||||
==========================================
|
|
||||||
|
|
||||||
This module implements aggregation strategies for multi-vector embeddings,
|
|
||||||
similar to ColPali's approach where multiple patch vectors represent a single document.
|
|
||||||
|
|
||||||
Key features:
|
|
||||||
- MaxSim aggregation (take maximum similarity across patches)
|
|
||||||
- Voting-based aggregation (count patch matches)
|
|
||||||
- Weighted aggregation (attention-score weighted)
|
|
||||||
- Spatial clustering of matching patches
|
|
||||||
- Document-level result consolidation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from collections import defaultdict
|
|
||||||
import json
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PatchResult:
|
|
||||||
"""Represents a single patch search result."""
|
|
||||||
patch_id: int
|
|
||||||
image_name: str
|
|
||||||
image_path: str
|
|
||||||
coordinates: Tuple[int, int, int, int] # (x1, y1, x2, y2)
|
|
||||||
score: float
|
|
||||||
attention_score: float
|
|
||||||
scale: float
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AggregatedResult:
|
|
||||||
"""Represents an aggregated document-level result."""
|
|
||||||
image_name: str
|
|
||||||
image_path: str
|
|
||||||
doc_score: float
|
|
||||||
patch_count: int
|
|
||||||
best_patch: PatchResult
|
|
||||||
all_patches: List[PatchResult]
|
|
||||||
aggregation_method: str
|
|
||||||
spatial_clusters: Optional[List[List[PatchResult]]] = None
|
|
||||||
|
|
||||||
class MultiVectorAggregator:
|
|
||||||
"""
|
|
||||||
Aggregates multiple patch-level results into document-level results.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
aggregation_method: str = "maxsim",
|
|
||||||
spatial_clustering: bool = True,
|
|
||||||
cluster_distance_threshold: float = 100.0):
|
|
||||||
"""
|
|
||||||
Initialize the aggregator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
aggregation_method: "maxsim", "voting", "weighted", or "mean"
|
|
||||||
spatial_clustering: Whether to cluster spatially close patches
|
|
||||||
cluster_distance_threshold: Distance threshold for spatial clustering
|
|
||||||
"""
|
|
||||||
self.aggregation_method = aggregation_method
|
|
||||||
self.spatial_clustering = spatial_clustering
|
|
||||||
self.cluster_distance_threshold = cluster_distance_threshold
|
|
||||||
|
|
||||||
def aggregate_results(self,
|
|
||||||
search_results: List[Dict[str, Any]],
|
|
||||||
top_k: int = 10) -> List[AggregatedResult]:
|
|
||||||
"""
|
|
||||||
Aggregate patch-level search results into document-level results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
search_results: List of search results from LeannSearcher
|
|
||||||
top_k: Number of top documents to return
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of aggregated document results
|
|
||||||
"""
|
|
||||||
# Group results by image
|
|
||||||
image_groups = defaultdict(list)
|
|
||||||
|
|
||||||
for result in search_results:
|
|
||||||
metadata = result.metadata
|
|
||||||
if "image_name" in metadata and "patch_id" in metadata:
|
|
||||||
patch_result = PatchResult(
|
|
||||||
patch_id=metadata["patch_id"],
|
|
||||||
image_name=metadata["image_name"],
|
|
||||||
image_path=metadata["image_path"],
|
|
||||||
coordinates=tuple(metadata["coordinates"]),
|
|
||||||
score=result.score,
|
|
||||||
attention_score=metadata.get("attention_score", 0.0),
|
|
||||||
scale=metadata.get("scale", 1.0),
|
|
||||||
metadata=metadata
|
|
||||||
)
|
|
||||||
image_groups[metadata["image_name"]].append(patch_result)
|
|
||||||
|
|
||||||
# Aggregate each image group
|
|
||||||
aggregated_results = []
|
|
||||||
for image_name, patches in image_groups.items():
|
|
||||||
if len(patches) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
agg_result = self._aggregate_image_patches(image_name, patches)
|
|
||||||
aggregated_results.append(agg_result)
|
|
||||||
|
|
||||||
# Sort by aggregated score and return top-k
|
|
||||||
aggregated_results.sort(key=lambda x: x.doc_score, reverse=True)
|
|
||||||
return aggregated_results[:top_k]
|
|
||||||
|
|
||||||
def _aggregate_image_patches(self, image_name: str, patches: List[PatchResult]) -> AggregatedResult:
|
|
||||||
"""Aggregate patches for a single image."""
|
|
||||||
|
|
||||||
if self.aggregation_method == "maxsim":
|
|
||||||
doc_score = max(patch.score for patch in patches)
|
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
|
||||||
|
|
||||||
elif self.aggregation_method == "voting":
|
|
||||||
# Count patches above threshold
|
|
||||||
threshold = np.percentile([p.score for p in patches], 75)
|
|
||||||
doc_score = sum(1 for patch in patches if patch.score >= threshold)
|
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
|
||||||
|
|
||||||
elif self.aggregation_method == "weighted":
|
|
||||||
# Weight by attention scores
|
|
||||||
total_weighted_score = sum(p.score * p.attention_score for p in patches)
|
|
||||||
total_weights = sum(p.attention_score for p in patches)
|
|
||||||
doc_score = total_weighted_score / max(total_weights, 1e-8)
|
|
||||||
best_patch = max(patches, key=lambda p: p.score * p.attention_score)
|
|
||||||
|
|
||||||
elif self.aggregation_method == "mean":
|
|
||||||
doc_score = np.mean([patch.score for patch in patches])
|
|
||||||
best_patch = max(patches, key=lambda p: p.score)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown aggregation method: {self.aggregation_method}")
|
|
||||||
|
|
||||||
# Spatial clustering if enabled
|
|
||||||
spatial_clusters = None
|
|
||||||
if self.spatial_clustering:
|
|
||||||
spatial_clusters = self._cluster_patches_spatially(patches)
|
|
||||||
|
|
||||||
return AggregatedResult(
|
|
||||||
image_name=image_name,
|
|
||||||
image_path=patches[0].image_path,
|
|
||||||
doc_score=float(doc_score),
|
|
||||||
patch_count=len(patches),
|
|
||||||
best_patch=best_patch,
|
|
||||||
all_patches=sorted(patches, key=lambda p: p.score, reverse=True),
|
|
||||||
aggregation_method=self.aggregation_method,
|
|
||||||
spatial_clusters=spatial_clusters
|
|
||||||
)
|
|
||||||
|
|
||||||
def _cluster_patches_spatially(self, patches: List[PatchResult]) -> List[List[PatchResult]]:
|
|
||||||
"""Cluster patches that are spatially close to each other."""
|
|
||||||
if len(patches) <= 1:
|
|
||||||
return [patches]
|
|
||||||
|
|
||||||
clusters = []
|
|
||||||
remaining_patches = patches.copy()
|
|
||||||
|
|
||||||
while remaining_patches:
|
|
||||||
# Start new cluster with highest scoring remaining patch
|
|
||||||
seed_patch = max(remaining_patches, key=lambda p: p.score)
|
|
||||||
current_cluster = [seed_patch]
|
|
||||||
remaining_patches.remove(seed_patch)
|
|
||||||
|
|
||||||
# Add nearby patches to cluster
|
|
||||||
added_to_cluster = True
|
|
||||||
while added_to_cluster:
|
|
||||||
added_to_cluster = False
|
|
||||||
for patch in remaining_patches.copy():
|
|
||||||
if self._is_patch_nearby(patch, current_cluster):
|
|
||||||
current_cluster.append(patch)
|
|
||||||
remaining_patches.remove(patch)
|
|
||||||
added_to_cluster = True
|
|
||||||
|
|
||||||
clusters.append(current_cluster)
|
|
||||||
|
|
||||||
return sorted(clusters, key=lambda cluster: max(p.score for p in cluster), reverse=True)
|
|
||||||
|
|
||||||
def _is_patch_nearby(self, patch: PatchResult, cluster: List[PatchResult]) -> bool:
|
|
||||||
"""Check if a patch is spatially close to any patch in the cluster."""
|
|
||||||
patch_center = self._get_patch_center(patch.coordinates)
|
|
||||||
|
|
||||||
for cluster_patch in cluster:
|
|
||||||
cluster_center = self._get_patch_center(cluster_patch.coordinates)
|
|
||||||
distance = np.sqrt((patch_center[0] - cluster_center[0])**2 +
|
|
||||||
(patch_center[1] - cluster_center[1])**2)
|
|
||||||
|
|
||||||
if distance <= self.cluster_distance_threshold:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _get_patch_center(self, coordinates: Tuple[int, int, int, int]) -> Tuple[float, float]:
|
|
||||||
"""Get center point of a patch."""
|
|
||||||
x1, y1, x2, y2 = coordinates
|
|
||||||
return ((x1 + x2) / 2, (y1 + y2) / 2)
|
|
||||||
|
|
||||||
def print_aggregated_results(self, results: List[AggregatedResult], max_patches_per_doc: int = 3):
|
|
||||||
"""Pretty print aggregated results."""
|
|
||||||
print(f"\n🔍 Aggregated Results (method: {self.aggregation_method})")
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f"\n{i+1}. {result.image_name}")
|
|
||||||
print(f" Doc Score: {result.doc_score:.4f} | Patches: {result.patch_count}")
|
|
||||||
print(f" Path: {result.image_path}")
|
|
||||||
|
|
||||||
# Show best patch
|
|
||||||
best = result.best_patch
|
|
||||||
print(f" 🌟 Best Patch: #{best.patch_id} at {best.coordinates} (score: {best.score:.4f})")
|
|
||||||
|
|
||||||
# Show top patches
|
|
||||||
print(f" 📍 Top Patches:")
|
|
||||||
for j, patch in enumerate(result.all_patches[:max_patches_per_doc]):
|
|
||||||
print(f" {j+1}. Patch #{patch.patch_id}: {patch.score:.4f} at {patch.coordinates}")
|
|
||||||
|
|
||||||
# Show spatial clusters if available
|
|
||||||
if result.spatial_clusters and len(result.spatial_clusters) > 1:
|
|
||||||
print(f" 🗂️ Spatial Clusters: {len(result.spatial_clusters)}")
|
|
||||||
for j, cluster in enumerate(result.spatial_clusters[:2]): # Show top 2 clusters
|
|
||||||
cluster_score = max(p.score for p in cluster)
|
|
||||||
print(f" Cluster {j+1}: {len(cluster)} patches (best: {cluster_score:.4f})")
|
|
||||||
|
|
||||||
def demo_aggregation():
|
|
||||||
"""Demonstrate the multi-vector aggregation functionality."""
|
|
||||||
print("=== Multi-Vector Aggregation Demo ===")
|
|
||||||
|
|
||||||
# Simulate some patch-level search results
|
|
||||||
# In real usage, these would come from LeannSearcher.search()
|
|
||||||
|
|
||||||
class MockResult:
|
|
||||||
def __init__(self, score, metadata):
|
|
||||||
self.score = score
|
|
||||||
self.metadata = metadata
|
|
||||||
|
|
||||||
# Simulate results for 2 images with multiple patches each
|
|
||||||
mock_results = [
|
|
||||||
# Image 1: cats_and_kitchen.jpg - 4 patches
|
|
||||||
MockResult(0.85, {
|
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
|
||||||
"patch_id": 3,
|
|
||||||
"coordinates": [100, 50, 224, 174], # Kitchen area
|
|
||||||
"attention_score": 0.92,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
MockResult(0.78, {
|
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
|
||||||
"patch_id": 7,
|
|
||||||
"coordinates": [200, 300, 324, 424], # Cat area
|
|
||||||
"attention_score": 0.88,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
MockResult(0.72, {
|
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
|
||||||
"patch_id": 12,
|
|
||||||
"coordinates": [150, 100, 274, 224], # Appliances
|
|
||||||
"attention_score": 0.75,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
MockResult(0.65, {
|
|
||||||
"image_name": "cats_and_kitchen.jpg",
|
|
||||||
"image_path": "/path/to/cats_and_kitchen.jpg",
|
|
||||||
"patch_id": 15,
|
|
||||||
"coordinates": [50, 250, 174, 374], # Furniture
|
|
||||||
"attention_score": 0.70,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
|
|
||||||
# Image 2: city_street.jpg - 3 patches
|
|
||||||
MockResult(0.68, {
|
|
||||||
"image_name": "city_street.jpg",
|
|
||||||
"image_path": "/path/to/city_street.jpg",
|
|
||||||
"patch_id": 2,
|
|
||||||
"coordinates": [300, 100, 424, 224], # Buildings
|
|
||||||
"attention_score": 0.80,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
MockResult(0.62, {
|
|
||||||
"image_name": "city_street.jpg",
|
|
||||||
"image_path": "/path/to/city_street.jpg",
|
|
||||||
"patch_id": 8,
|
|
||||||
"coordinates": [100, 350, 224, 474], # Street level
|
|
||||||
"attention_score": 0.75,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
MockResult(0.55, {
|
|
||||||
"image_name": "city_street.jpg",
|
|
||||||
"image_path": "/path/to/city_street.jpg",
|
|
||||||
"patch_id": 11,
|
|
||||||
"coordinates": [400, 200, 524, 324], # Sky area
|
|
||||||
"attention_score": 0.60,
|
|
||||||
"scale": 1.0
|
|
||||||
}),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Test different aggregation methods
|
|
||||||
methods = ["maxsim", "voting", "weighted", "mean"]
|
|
||||||
|
|
||||||
for method in methods:
|
|
||||||
print(f"\n{'='*20} {method.upper()} AGGREGATION {'='*20}")
|
|
||||||
|
|
||||||
aggregator = MultiVectorAggregator(
|
|
||||||
aggregation_method=method,
|
|
||||||
spatial_clustering=True,
|
|
||||||
cluster_distance_threshold=100.0
|
|
||||||
)
|
|
||||||
|
|
||||||
aggregated = aggregator.aggregate_results(mock_results, top_k=5)
|
|
||||||
aggregator.print_aggregated_results(aggregated)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
demo_aggregation()
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
OpenAI Embedding Example
|
|
||||||
|
|
||||||
Complete example showing how to build and search with OpenAI embeddings using HNSW backend.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import dotenv
|
|
||||||
from pathlib import Path
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
# Load environment variables
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# Check if OpenAI API key is available
|
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"✅ OpenAI API key found: {api_key[:10]}...")
|
|
||||||
|
|
||||||
# Sample texts
|
|
||||||
sample_texts = [
|
|
||||||
"Machine learning is a powerful technology that enables computers to learn from data.",
|
|
||||||
"Natural language processing helps computers understand and generate human language.",
|
|
||||||
"Deep learning uses neural networks with multiple layers to solve complex problems.",
|
|
||||||
"Computer vision allows machines to interpret and understand visual information.",
|
|
||||||
"Reinforcement learning trains agents to make decisions through trial and error.",
|
|
||||||
"Data science combines statistics, math, and programming to extract insights from data.",
|
|
||||||
"Artificial intelligence aims to create machines that can perform human-like tasks.",
|
|
||||||
"Python is a popular programming language used extensively in data science and AI.",
|
|
||||||
"Neural networks are inspired by the structure and function of the human brain.",
|
|
||||||
"Big data refers to extremely large datasets that require special tools to process."
|
|
||||||
]
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./simple_openai_test_index")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "simple_test.leann")
|
|
||||||
|
|
||||||
print(f"\n=== Building Index with OpenAI Embeddings ===")
|
|
||||||
print(f"Index path: {INDEX_PATH}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use proper configuration for OpenAI embeddings
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="text-embedding-3-small",
|
|
||||||
embedding_mode="openai",
|
|
||||||
# HNSW settings for OpenAI embeddings
|
|
||||||
M=16, # Smaller graph degree
|
|
||||||
efConstruction=64, # Smaller construction complexity
|
|
||||||
is_compact=True, # Enable compact storage for recompute
|
|
||||||
is_recompute=True, # MUST enable for OpenAI embeddings
|
|
||||||
num_threads=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(sample_texts)} texts to the index...")
|
|
||||||
for i, text in enumerate(sample_texts):
|
|
||||||
metadata = {"id": f"doc_{i}", "topic": "AI"}
|
|
||||||
builder.add_text(text, metadata)
|
|
||||||
|
|
||||||
print("Building index...")
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"✅ Index built successfully!")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error building index: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
print(f"\n=== Testing Search ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
searcher = LeannSearcher(INDEX_PATH)
|
|
||||||
|
|
||||||
test_queries = [
|
|
||||||
"What is machine learning?",
|
|
||||||
"How do neural networks work?",
|
|
||||||
"Programming languages for data science"
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in test_queries:
|
|
||||||
print(f"\n🔍 Query: '{query}'")
|
|
||||||
results = searcher.search(query, top_k=3)
|
|
||||||
|
|
||||||
print(f" Found {len(results)} results:")
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
print(f" {i+1}. Score: {result.score:.4f}")
|
|
||||||
print(f" Text: {result.text[:80]}...")
|
|
||||||
|
|
||||||
print(f"\n✅ Search test completed successfully!")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error during search: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
if success:
|
|
||||||
print(f"\n🎉 Simple OpenAI index test completed successfully!")
|
|
||||||
else:
|
|
||||||
print(f"\n💥 Simple OpenAI index test failed!")
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from leann.api import LeannChat
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
INDEX_DIR = Path("./test_pdf_index_huawei")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH)
|
|
||||||
query = "What is the main idea of RL and give me 5 exapmle of classic RL algorithms?"
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
# query = "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
|
||||||
response = chat.ask(query,top_k=20,recompute_beighbor_embeddings=True,complexity=32,beam_width=1)
|
|
||||||
print(f"\n[PHASE 2] Response: {response}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
250
examples/spoiler_free_book_rag.py
Normal file
250
examples/spoiler_free_book_rag.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spoiler-Free Book RAG Example using LEANN Metadata Filtering
|
||||||
|
|
||||||
|
This example demonstrates how to use LEANN's metadata filtering to create
|
||||||
|
a spoiler-free book RAG system where users can search for information
|
||||||
|
up to a specific chapter they've read.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python spoiler_free_book_rag.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
# Add LEANN to path (adjust path as needed)
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../packages/leann-core/src"))
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_book_with_metadata(book_title: str = "Sample Book") -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Create sample book chunks with metadata for demonstration.
|
||||||
|
|
||||||
|
In a real implementation, this would parse actual book files (epub, txt, etc.)
|
||||||
|
and extract chapter boundaries, character mentions, etc.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_title: Title of the book
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with text and metadata
|
||||||
|
"""
|
||||||
|
# Sample book chunks with metadata
|
||||||
|
# In practice, you'd use proper text processing libraries
|
||||||
|
|
||||||
|
sample_chunks = [
|
||||||
|
{
|
||||||
|
"text": "Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 1,
|
||||||
|
"characters": ["Alice", "Sister"],
|
||||||
|
"themes": ["boredom", "curiosity"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes ran close by her.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 1,
|
||||||
|
"page": 2,
|
||||||
|
"characters": ["Alice", "White Rabbit"],
|
||||||
|
"themes": ["decision", "surprise", "magic"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she fell to look about her and to wonder what was going to happen next.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 2,
|
||||||
|
"page": 15,
|
||||||
|
"characters": ["Alice"],
|
||||||
|
"themes": ["falling", "wonder", "transformation"],
|
||||||
|
"location": "rabbit hole",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice meets the Cheshire Cat, who tells her that everyone in Wonderland is mad, including Alice herself.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 6,
|
||||||
|
"page": 85,
|
||||||
|
"characters": ["Alice", "Cheshire Cat"],
|
||||||
|
"themes": ["madness", "philosophy", "identity"],
|
||||||
|
"location": "Duchess's house",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "At the Queen's croquet ground, Alice witnesses the absurd trial that reveals the arbitrary nature of Wonderland's justice system.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 8,
|
||||||
|
"page": 120,
|
||||||
|
"characters": ["Alice", "Queen of Hearts", "King of Hearts"],
|
||||||
|
"themes": ["justice", "absurdity", "authority"],
|
||||||
|
"location": "Queen's court",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"text": "Alice realizes that Wonderland was all a dream, even the Rabbit, as she wakes up on the riverbank next to her sister.",
|
||||||
|
"metadata": {
|
||||||
|
"book": book_title,
|
||||||
|
"chapter": 12,
|
||||||
|
"page": 180,
|
||||||
|
"characters": ["Alice", "Sister", "Rabbit"],
|
||||||
|
"themes": ["revelation", "reality", "growth"],
|
||||||
|
"location": "riverbank",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return sample_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def build_spoiler_free_index(book_chunks: list[dict[str, Any]], index_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a LEANN index with book chunks that include spoiler metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
book_chunks: List of book chunks with metadata
|
||||||
|
index_name: Name for the index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the built index
|
||||||
|
"""
|
||||||
|
print(f"📚 Building spoiler-free book index: {index_name}")
|
||||||
|
|
||||||
|
# Initialize LEANN builder
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw", embedding_model="text-embedding-3-small", embedding_mode="openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add each chunk with its metadata
|
||||||
|
for chunk in book_chunks:
|
||||||
|
builder.add_text(text=chunk["text"], metadata=chunk["metadata"])
|
||||||
|
|
||||||
|
# Build the index
|
||||||
|
index_path = f"{index_name}_book_index"
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
print(f"✅ Index built successfully: {index_path}")
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
|
||||||
|
def spoiler_free_search(
|
||||||
|
index_path: str,
|
||||||
|
query: str,
|
||||||
|
max_chapter: int,
|
||||||
|
character_filter: Optional[list[str]] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Perform a spoiler-free search on the book index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_path: Path to the LEANN index
|
||||||
|
query: Search query
|
||||||
|
max_chapter: Maximum chapter number to include
|
||||||
|
character_filter: Optional list of characters to focus on
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of search results safe for the reader
|
||||||
|
"""
|
||||||
|
print(f"🔍 Searching: '{query}' (up to chapter {max_chapter})")
|
||||||
|
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
metadata_filters = {"chapter": {"<=": max_chapter}}
|
||||||
|
|
||||||
|
if character_filter:
|
||||||
|
metadata_filters["characters"] = {"contains": character_filter[0]}
|
||||||
|
|
||||||
|
results = searcher.search(query=query, top_k=10, metadata_filters=metadata_filters)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def demo_spoiler_free_rag():
|
||||||
|
"""
|
||||||
|
Demonstrate the spoiler-free book RAG system.
|
||||||
|
"""
|
||||||
|
print("🎭 Spoiler-Free Book RAG Demo")
|
||||||
|
print("=" * 40)
|
||||||
|
|
||||||
|
# Step 1: Prepare book data
|
||||||
|
book_title = "Alice's Adventures in Wonderland"
|
||||||
|
book_chunks = chunk_book_with_metadata(book_title)
|
||||||
|
|
||||||
|
print(f"📖 Loaded {len(book_chunks)} chunks from '{book_title}'")
|
||||||
|
|
||||||
|
# Step 2: Build the index (in practice, this would be done once)
|
||||||
|
try:
|
||||||
|
index_path = build_spoiler_free_index(book_chunks, "alice_wonderland")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to build index (likely missing dependencies): {e}")
|
||||||
|
print(
|
||||||
|
"💡 This demo shows the filtering logic - actual indexing requires LEANN dependencies"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 3: Demonstrate various spoiler-free searches
|
||||||
|
search_scenarios = [
|
||||||
|
{
|
||||||
|
"description": "Reader who has only read Chapter 1",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read up to Chapter 5",
|
||||||
|
"query": "Tell me about Alice's adventures",
|
||||||
|
"max_chapter": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read most of the book",
|
||||||
|
"query": "What does the Cheshire Cat represent?",
|
||||||
|
"max_chapter": 10,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"description": "Reader who has read the whole book",
|
||||||
|
"query": "What can you tell me about the rabbit?",
|
||||||
|
"max_chapter": 12,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for scenario in search_scenarios:
|
||||||
|
print(f"\n📚 Scenario: {scenario['description']}")
|
||||||
|
print(f" Query: {scenario['query']}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = spoiler_free_search(
|
||||||
|
index_path=index_path,
|
||||||
|
query=scenario["query"],
|
||||||
|
max_chapter=scenario["max_chapter"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📄 Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results[:3], 1): # Show top 3
|
||||||
|
chapter = result.metadata.get("chapter", "?")
|
||||||
|
location = result.metadata.get("location", "?")
|
||||||
|
print(f" {i}. Chapter {chapter} ({location}): {result.text[:80]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Search failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("📚 LEANN Spoiler-Free Book RAG Example")
|
||||||
|
print("=====================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
demo_spoiler_free_rag()
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ Cannot run demo due to missing dependencies: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error running demo: {e}")
|
||||||
@@ -1,319 +0,0 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any, Optional
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
# Default WeChat export directory
|
|
||||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
|
||||||
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_wechat_exports(
|
|
||||||
export_dirs: List[Path],
|
|
||||||
index_path: str = "wechat_history_index.leann",
|
|
||||||
max_count: int = -1,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create LEANN index from multiple WeChat export data sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_dirs: List of Path objects pointing to WeChat export directories
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of chat entries to process per export
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
|
||||||
|
|
||||||
# Load documents using WeChatHistoryReader from history_data
|
|
||||||
from history_data.wechat_history import WeChatHistoryReader
|
|
||||||
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
# Process each WeChat export directory
|
|
||||||
for i, export_dir in enumerate(export_dirs):
|
|
||||||
print(
|
|
||||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
documents = reader.load_data(
|
|
||||||
wechat_export_dir=str(export_dir),
|
|
||||||
max_count=max_count,
|
|
||||||
concatenate_messages=True, # Disable concatenation - one message per document
|
|
||||||
)
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
|
|
||||||
# Check if we've reached the max count
|
|
||||||
if max_count > 0 and total_processed >= max_count:
|
|
||||||
print(f"Reached max count of {max_count} documents")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {export_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports and starting to split them into chunks"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Finished splitting {len(all_documents)} documents into {len(all_texts)} text chunks"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
|
|
||||||
def create_leann_index(
|
|
||||||
export_dir: str = None,
|
|
||||||
index_path: str = "wechat_history_index.leann",
|
|
||||||
max_count: int = 1000,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create LEANN index from WeChat chat history data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_dir: Path to the WeChat export directory (optional, uses default if None)
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of chat entries to process
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from WeChat chat history data...")
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Load documents using WeChatHistoryReader from history_data
|
|
||||||
from history_data.wechat_history import WeChatHistoryReader
|
|
||||||
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
documents = reader.load_data(
|
|
||||||
wechat_export_dir=export_dir,
|
|
||||||
max_count=max_count,
|
|
||||||
concatenate_messages=False, # Disable concatenation - one message per document
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
|
||||||
print("No documents loaded. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"Loaded {len(documents)} chat documents")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ", # MLX-optimized model
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
|
||||||
"""
|
|
||||||
Query the LEANN index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: The query string
|
|
||||||
"""
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=index_path)
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=20,
|
|
||||||
recompute_beighbor_embeddings=True,
|
|
||||||
complexity=16,
|
|
||||||
beam_width=1,
|
|
||||||
llm_config={
|
|
||||||
"type": "openai",
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
|
||||||
)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main function with integrated WeChat export functionality."""
|
|
||||||
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--export-dir",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
|
||||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default="./wechat_history_magic_test_11Debug_new",
|
|
||||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-entries",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Maximum number of chat entries to process (default: 5000)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Single query to run (default: runs example queries)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--force-export",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Force re-export of WeChat data even if exports exist",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
|
||||||
|
|
||||||
print(f"Using WeChat export directory: {args.export_dir}")
|
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
|
||||||
print(f"Max entries: {args.max_entries}")
|
|
||||||
|
|
||||||
# Initialize WeChat reader with export capabilities
|
|
||||||
from history_data.wechat_history import WeChatHistoryReader
|
|
||||||
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
# Find existing exports or create new ones using the centralized method
|
|
||||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
|
||||||
if not export_dirs:
|
|
||||||
print("Failed to find or export WeChat data. Exiting.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
|
||||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
|
||||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
|
||||||
)
|
|
||||||
|
|
||||||
if index_path:
|
|
||||||
if args.query:
|
|
||||||
# Run single query
|
|
||||||
await query_leann_index(index_path, args.query)
|
|
||||||
else:
|
|
||||||
# Example queries
|
|
||||||
queries = [
|
|
||||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
await query_leann_index(index_path, query)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
28
llms.txt
Normal file
28
llms.txt
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# llms.txt — LEANN MCP and Agent Integration
|
||||||
|
product: LEANN
|
||||||
|
homepage: https://github.com/yichuan-w/LEANN
|
||||||
|
contact: https://github.com/yichuan-w/LEANN/issues
|
||||||
|
|
||||||
|
# Installation
|
||||||
|
install: uv tool install leann-core --with leann
|
||||||
|
|
||||||
|
# MCP Server Entry Point
|
||||||
|
mcp.server: leann_mcp
|
||||||
|
mcp.protocol_version: 2024-11-05
|
||||||
|
|
||||||
|
# Tools
|
||||||
|
mcp.tools: leann_list, leann_search
|
||||||
|
|
||||||
|
mcp.tool.leann_list.description: List available LEANN indexes
|
||||||
|
mcp.tool.leann_list.input: {}
|
||||||
|
|
||||||
|
mcp.tool.leann_search.description: Semantic search across a named LEANN index
|
||||||
|
mcp.tool.leann_search.input.index_name: string, required
|
||||||
|
mcp.tool.leann_search.input.query: string, required
|
||||||
|
mcp.tool.leann_search.input.top_k: integer, optional, default=5, min=1, max=20
|
||||||
|
mcp.tool.leann_search.input.complexity: integer, optional, default=32, min=16, max=128
|
||||||
|
|
||||||
|
# Notes
|
||||||
|
note: Build indexes with `leann build <name> --docs <files...>` before searching.
|
||||||
|
example.add: claude mcp add --scope user leann-server -- leann_mcp
|
||||||
|
example.verify: claude mcp list | cat
|
||||||
@@ -1 +0,0 @@
|
|||||||
|
|
||||||
|
|||||||
1
packages/astchunk-leann
Submodule
1
packages/astchunk-leann
Submodule
Submodule packages/astchunk-leann added at ad9afa07b9
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user