Compare commits
261 Commits
1  .gitattributes  (vendored)
@@ -1 +0,0 @@
-paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
50  .github/ISSUE_TEMPLATE/bug_report.yml  (vendored, new file)
@@ -0,0 +1,50 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]

body:
  - type: textarea
    id: description
    attributes:
      label: What happened?
      description: A clear description of the bug
    validations:
      required: true

  - type: textarea
    id: reproduce
    attributes:
      label: How to reproduce
      placeholder: |
        1. Install with...
        2. Run command...
        3. See error
    validations:
      required: true

  - type: textarea
    id: error
    attributes:
      label: Error message
      description: Paste any error messages
      render: shell

  - type: input
    id: version
    attributes:
      label: LEANN Version
      placeholder: "0.1.0"
    validations:
      required: true

  - type: dropdown
    id: os
    attributes:
      label: Operating System
      options:
        - macOS
        - Linux
        - Windows
        - Docker
    validations:
      required: true
8  .github/ISSUE_TEMPLATE/config.yml  (vendored, new file)
@@ -0,0 +1,8 @@
blank_issues_enabled: true
contact_links:
  - name: Documentation
    url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
    about: Read the docs first
  - name: Discussions
    url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
    about: Ask questions and share ideas
27  .github/ISSUE_TEMPLATE/feature_request.yml  (vendored, new file)
@@ -0,0 +1,27 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]

body:
  - type: textarea
    id: problem
    attributes:
      label: What problem does this solve?
      description: Describe the problem or need
    validations:
      required: true

  - type: textarea
    id: solution
    attributes:
      label: Proposed solution
      description: How would you like this to work?
    validations:
      required: true

  - type: textarea
    id: example
    attributes:
      label: Example usage
      description: Show how the API might look
      render: python
13  .github/pull_request_template.md  (vendored, new file)
@@ -0,0 +1,13 @@
## What does this PR do?

<!-- Brief description of your changes -->

## Related Issues

Fixes #

## Checklist

- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
12  .github/workflows/build-and-publish.yml  (vendored, new file)
@@ -0,0 +1,12 @@
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]
  workflow_dispatch:

jobs:
  build:
    uses: ./.github/workflows/build-reusable.yml
451  .github/workflows/build-reusable.yml  (vendored, new file)
@@ -0,0 +1,451 @@
|
|||||||
|
name: Reusable Build
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_call:
|
||||||
|
inputs:
|
||||||
|
ref:
|
||||||
|
description: 'Git ref to build'
|
||||||
|
required: false
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint and Format Check
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install uv and Python
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Run pre-commit with only lint group (no project deps)
|
||||||
|
run: |
|
||||||
|
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||||
|
|
||||||
|
|
||||||
|
build:
|
||||||
|
needs: lint
|
||||||
|
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-22.04
|
||||||
|
python: '3.13'
|
||||||
|
# ARM64 Linux builds
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.9'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.10'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.11'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.12'
|
||||||
|
- os: ubuntu-24.04-arm
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-14
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.12'
|
||||||
|
- os: macos-15
|
||||||
|
python: '3.13'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.9'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.10'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.11'
|
||||||
|
- os: macos-13
|
||||||
|
python: '3.12'
|
||||||
|
# Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
|
||||||
|
# (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
ref: ${{ inputs.ref }}
|
||||||
|
submodules: recursive
|
||||||
|
|
||||||
|
- name: Install uv and Python
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python }}
|
||||||
|
|
||||||
|
- name: Install system dependencies (Ubuntu)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||||
|
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||||
|
patchelf
|
||||||
|
|
||||||
|
# Debug: Show system information
|
||||||
|
echo "🔍 System Information:"
|
||||||
|
echo "Architecture: $(uname -m)"
|
||||||
|
echo "OS: $(uname -a)"
|
||||||
|
echo "CPU info: $(lscpu | head -5)"
|
||||||
|
|
||||||
|
# Install math library based on architecture
|
||||||
|
ARCH=$(uname -m)
|
||||||
|
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||||
|
|
||||||
|
if [[ "$ARCH" == "x86_64" ]]; then
|
||||||
|
# Install Intel MKL for DiskANN on x86_64
|
||||||
|
echo "📦 Installing Intel MKL for x86_64..."
|
||||||
|
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||||
|
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||||
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||||
|
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||||
|
echo "✅ Intel MKL installed for x86_64"
|
||||||
|
|
||||||
|
# Debug: Check MKL installation
|
||||||
|
echo "🔍 MKL Installation Check:"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||||
|
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||||
|
|
||||||
|
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||||
|
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||||
|
echo "📦 Installing OpenBLAS for ARM64..."
|
||||||
|
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||||
|
echo "✅ OpenBLAS installed for ARM64"
|
||||||
|
|
||||||
|
# Debug: Check OpenBLAS installation
|
||||||
|
echo "🔍 OpenBLAS Installation Check:"
|
||||||
|
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||||
|
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Debug: Show final library paths
|
||||||
|
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||||
|
|
||||||
|
- name: Install system dependencies (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Don't install LLVM, use system clang for better compatibility
|
||||||
|
brew install libomp boost protobuf zeromq
|
||||||
|
|
||||||
|
- name: Install build dependencies
|
||||||
|
run: |
|
||||||
|
uv python install ${{ matrix.python }}
|
||||||
|
uv venv --python ${{ matrix.python }} .uv-build
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||||
|
else
|
||||||
|
BUILD_PY=".uv-build/bin/python"
|
||||||
|
fi
|
||||||
|
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||||
|
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||||
|
uv pip install --python "$BUILD_PY" auditwheel
|
||||||
|
else
|
||||||
|
uv pip install --python "$BUILD_PY" delocate
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||||
|
else
|
||||||
|
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Set macOS environment variables
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Use brew --prefix to automatically detect Homebrew installation path
|
||||||
|
HOMEBREW_PREFIX=$(brew --prefix)
|
||||||
|
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||||
|
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
# Set compiler flags for OpenMP (required for both backends)
|
||||||
|
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||||
|
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Build packages
|
||||||
|
run: |
|
||||||
|
# Build core (platform independent)
|
||||||
|
cd packages/leann-core
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build HNSW backend
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
|
# Use system clang for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
else
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build DiskANN backend
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||||
|
# Use system clang for better compatibility
|
||||||
|
export CC=clang
|
||||||
|
export CXX=clang++
|
||||||
|
# DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
|
||||||
|
# But Homebrew libraries on each macOS version require matching minimum version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
else
|
||||||
|
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Build meta package (platform independent)
|
||||||
|
cd packages/leann
|
||||||
|
uv build
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (Linux)
|
||||||
|
if: runner.os == 'Linux'
|
||||||
|
run: |
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
auditwheel repair dist/*.whl -w dist_repaired
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: Repair wheels (macOS)
|
||||||
|
if: runner.os == 'macOS'
|
||||||
|
run: |
|
||||||
|
# Determine deployment target based on runner OS
|
||||||
|
# Must match the Homebrew libraries for each macOS version
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
HNSW_TARGET="13.0"
|
||||||
|
DISKANN_TARGET="13.3"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
HNSW_TARGET="14.0"
|
||||||
|
DISKANN_TARGET="14.0"
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
HNSW_TARGET="15.0"
|
||||||
|
DISKANN_TARGET="15.0"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Repair HNSW wheel
|
||||||
|
cd packages/leann-backend-hnsw
|
||||||
|
if [ -d dist ]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
# Repair DiskANN wheel
|
||||||
|
cd packages/leann-backend-diskann
|
||||||
|
if [ -d dist ]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||||
|
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||||
|
rm -rf dist
|
||||||
|
mv dist_repaired dist
|
||||||
|
fi
|
||||||
|
cd ../..
|
||||||
|
|
||||||
|
- name: List built packages
|
||||||
|
run: |
|
||||||
|
echo "📦 Built packages:"
|
||||||
|
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||||
|
|
||||||
|
|
||||||
|
- name: Install built packages for testing
|
||||||
|
run: |
|
||||||
|
# Create uv-managed virtual environment with the requested interpreter
|
||||||
|
uv python install ${{ matrix.python }}
|
||||||
|
uv venv --python ${{ matrix.python }}
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||||
|
UV_PY=".venv\\Scripts\\python.exe"
|
||||||
|
else
|
||||||
|
UV_PY=".venv/bin/python"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install test dependency group only (avoids reinstalling project package)
|
||||||
|
uv pip install --python "$UV_PY" --group test
|
||||||
|
|
||||||
|
# Install core wheel built in this job
|
||||||
|
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$CORE_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
|
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||||
|
|
||||||
|
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||||
|
if [[ "${{ matrix.os }}" == "macos-13" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=13.3
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||||
|
elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
|
||||||
|
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$HNSW_WHL" ]]; then
|
||||||
|
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$HNSW_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||||
|
fi
|
||||||
|
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||||
|
if [[ -z "$DISKANN_WHL" ]]; then
|
||||||
|
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||||
|
fi
|
||||||
|
if [[ -n "$DISKANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||||
|
fi
|
||||||
|
|
||||||
|
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||||
|
if [[ -n "$LEANN_WHL" ]]; then
|
||||||
|
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||||
|
else
|
||||||
|
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Run tests with pytest
|
||||||
|
env:
|
||||||
|
CI: true
|
||||||
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
HF_HUB_DISABLE_SYMLINKS: 1
|
||||||
|
TOKENIZERS_PARALLELISM: false
|
||||||
|
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
- name: Run sanity checks (optional)
|
||||||
|
run: |
|
||||||
|
# Activate virtual environment
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
|
||||||
|
# Run distance function tests if available
|
||||||
|
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||||
|
echo "Running distance function sanity checks..."
|
||||||
|
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||||
|
path: packages/*/dist/
|
||||||
|
|
||||||
|
|
||||||
|
arch-smoke:
|
||||||
|
name: Arch Linux smoke test (install & import)
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: archlinux:latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Prepare system
|
||||||
|
run: |
|
||||||
|
pacman -Syu --noconfirm
|
||||||
|
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||||
|
|
||||||
|
- name: Download ALL wheel artifacts from this run
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
# Don't specify name, download all artifacts
|
||||||
|
path: ./wheels
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
|
||||||
|
- name: Create virtual environment and install wheels
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
uv pip install --find-links wheels leann-core
|
||||||
|
uv pip install --find-links wheels leann-backend-hnsw
|
||||||
|
uv pip install --find-links wheels leann-backend-diskann
|
||||||
|
uv pip install --find-links wheels leann
|
||||||
|
|
||||||
|
- name: Import & tiny runtime check
|
||||||
|
env:
|
||||||
|
OMP_NUM_THREADS: 1
|
||||||
|
MKL_NUM_THREADS: 1
|
||||||
|
run: |
|
||||||
|
source .venv/bin/activate || source .venv/Scripts/activate
|
||||||
|
python - <<'PY'
|
||||||
|
import leann
|
||||||
|
import leann_backend_hnsw as h
|
||||||
|
import leann_backend_diskann as d
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
b = LeannBuilder(backend_name="hnsw")
|
||||||
|
b.add_text("hello arch")
|
||||||
|
b.build_index("arch_demo.leann")
|
||||||
|
s = LeannSearcher("arch_demo.leann")
|
||||||
|
print("search:", s.search("hello", top_k=1))
|
||||||
|
PY
|
||||||
19  .github/workflows/link-check.yml  (vendored, new file)
@@ -0,0 +1,19 @@
name: Link Check

on:
  push:
    branches: [ main, master ]
  pull_request:
  schedule:
    - cron: "0 3 * * 1"

jobs:
  link-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: lycheeverse/lychee-action@v2
        with:
          args: --no-progress --insecure --user-agent 'curl/7.68.0' README.md docs/ apps/ examples/ benchmarks/
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
129  .github/workflows/release-manual.yml  (vendored, new file)
@@ -0,0 +1,129 @@
name: Release

on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version to release (e.g., 0.1.2)'
        required: true
        type: string

jobs:
  update-version:
    name: Update Version
    runs-on: ubuntu-latest
    permissions:
      contents: write
    outputs:
      commit-sha: ${{ steps.push.outputs.commit-sha }}

    steps:
      - uses: actions/checkout@v4

      - name: Validate version
        run: |
          # Remove 'v' prefix if present for validation
          VERSION_CLEAN="${{ inputs.version }}"
          VERSION_CLEAN="${VERSION_CLEAN#v}"
          if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
            exit 1
          fi
          echo "✅ Version format valid: ${{ inputs.version }}"

      - name: Update versions and push
        id: push
        run: |
          # Check current version
          CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
          echo "Current version: $CURRENT_VERSION"
          echo "Target version: ${{ inputs.version }}"

          if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
            echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
            COMMIT_SHA=$(git rev-parse HEAD)
          else
            ./scripts/bump_version.sh ${{ inputs.version }}
            git config user.name "GitHub Actions"
            git config user.email "actions@github.com"
            git add packages/*/pyproject.toml
            git commit -m "chore: release v${{ inputs.version }}"
            git push origin main
            COMMIT_SHA=$(git rev-parse HEAD)
            echo "✅ Pushed version update: $COMMIT_SHA"
          fi

          echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT

  build-packages:
    name: Build packages
    needs: update-version
    uses: ./.github/workflows/build-reusable.yml
    with:
      ref: 'main'

  publish:
    name: Publish and Release
    needs: [update-version, build-packages]
    if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - uses: actions/checkout@v4
        with:
          ref: 'main'

      - name: Download all artifacts
        uses: actions/download-artifact@v4
        with:
          path: dist-artifacts

      - name: Collect packages
        run: |
          mkdir -p dist
          find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
          find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;

          echo "📦 Packages to publish:"
          ls -la dist/

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          if [ -z "$TWINE_PASSWORD" ]; then
            echo "❌ PYPI_API_TOKEN not configured!"
            exit 1
          fi

          pip install twine
          twine upload dist/* --skip-existing --verbose

          echo "✅ Published to PyPI!"

      - name: Create release
        run: |
          # Check if tag already exists
          if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
            echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
          else
            git tag "v${{ inputs.version }}"
            git push origin "v${{ inputs.version }}"
            echo "✅ Created and pushed tag v${{ inputs.version }}"
          fi

          # Check if release already exists
          if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
            echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
          else
            gh release create "v${{ inputs.version }}" \
              --title "Release v${{ inputs.version }}" \
              --notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
              --latest
            echo "✅ Created GitHub release v${{ inputs.version }}"
          fi
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
24  .gitignore  (vendored)
@@ -12,16 +12,18 @@ outputs/
 *.idx
 *.map
 .history/
-scripts/
 lm_eval.egg-info/
 demo/experiment_results/**/*.json
 *.jsonl
 *.eml
 *.emlx
 *.json
+*.png
+!.vscode/*.json
 *.sh
 *.txt
 !CMakeLists.txt
+!llms.txt
 latency_breakdown*.json
 experiment_results/eval_results/diskann/*.json
 aws/
@@ -35,11 +37,15 @@ build/
 nprobe_logs/
 micro/results
 micro/contriever-INT8
-examples/data/*
-!examples/data/2501.14312v1 (1).pdf
-!examples/data/2506.08276v1.pdf
-!examples/data/PrideandPrejudice.txt
-!examples/data/README.md
+data/*
+!data/2501.14312v1 (1).pdf
+!data/2506.08276v1.pdf
+!data/PrideandPrejudice.txt
+!data/huawei_pangu.md
+!data/ground_truth/
+!data/indices/
+!data/queries/
+!data/.gitattributes
 *.qdstrm
 benchmark_results/
 results/
@@ -87,3 +93,9 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
 *.passages.json

 batchtest.py
+tests/__pytest_cache__/
+tests/__pycache__/
+benchmarks/data/
+
+## multi vector
+apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
3  .gitmodules  (vendored)
@@ -14,3 +14,6 @@
 [submodule "packages/leann-backend-hnsw/third_party/libzmq"]
 	path = packages/leann-backend-hnsw/third_party/libzmq
 	url = https://github.com/zeromq/libzmq.git
+[submodule "packages/astchunk-leann"]
+	path = packages/astchunk-leann
+	url = https://github.com/yichuan-w/astchunk-leann.git
17  .pre-commit-config.yaml  (new file)
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: debug-statements

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.7  # Fixed version to match pyproject.toml
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
      - id: ruff-format
5  .vscode/extensions.json  (vendored, new file)
@@ -0,0 +1,5 @@
{
  "recommendations": [
    "charliermarsh.ruff",
  ]
}
22  .vscode/settings.json  (vendored, new file)
@@ -0,0 +1,22 @@
{
  "python.defaultInterpreterPath": ".venv/bin/python",
  "python.terminal.activateEnvironment": true,
  "[python]": {
    "editor.defaultFormatter": "charliermarsh.ruff",
    "editor.formatOnSave": true,
    "editor.codeActionsOnSave": {
      "source.organizeImports": "explicit",
      "source.fixAll": "explicit"
    },
    "editor.insertSpaces": true,
    "editor.tabSize": 4
  },
  "ruff.enable": true,
  "files.watcherExclude": {
    "**/.venv/**": true,
    "**/__pycache__/**": true,
    "**/*.egg-info/**": true,
    "**/build/**": true,
    "**/dist/**": true
  }
}
387  apps/base_rag_example.py  (new file)
@@ -0,0 +1,387 @@
|
|||||||
|
"""
|
||||||
|
Base class for unified RAG examples interface.
|
||||||
|
Provides common parameters and functionality for all RAG examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import dotenv
|
||||||
|
from leann.api import LeannBuilder, LeannChat
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||||
|
|
||||||
|
dotenv.load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRAGExample(ABC):
|
||||||
|
"""Base class for all RAG examples with unified interface."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
default_index_name: str,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.default_index_name = default_index_name
|
||||||
|
self.parser = self._create_parser()
|
||||||
|
|
||||||
|
def _create_parser(self) -> argparse.ArgumentParser:
|
||||||
|
"""Create argument parser with common parameters."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Core parameters (all examples share these)
|
||||||
|
core_group = parser.add_argument_group("Core Parameters")
|
||||||
|
core_group.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
type=str,
|
||||||
|
default=f"./{self.default_index_name}",
|
||||||
|
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Query to run (if not provided, will run in interactive mode)",
|
||||||
|
)
|
||||||
|
# Allow subclasses to override default max_items
|
||||||
|
max_items_default = getattr(self, "max_items_default", -1)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--max-items",
|
||||||
|
type=int,
|
||||||
|
default=max_items_default,
|
||||||
|
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||||
|
)
|
||||||
|
core_group.add_argument(
|
||||||
|
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Embedding parameters
|
||||||
|
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||||
|
# Allow subclasses to override default embedding_model
|
||||||
|
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
type=str,
|
||||||
|
default=embedding_model_default,
|
||||||
|
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-mode",
|
||||||
|
type=str,
|
||||||
|
default="sentence-transformers",
|
||||||
|
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||||
|
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Override Ollama-compatible embedding host",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible embedding services",
|
||||||
|
)
|
||||||
|
embedding_group.add_argument(
|
||||||
|
"--embedding-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# LLM parameters
|
||||||
|
llm_group = parser.add_argument_group("LLM Parameters")
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm",
|
||||||
|
type=str,
|
||||||
|
default="openai",
|
||||||
|
choices=["openai", "ollama", "hf", "simulated"],
|
||||||
|
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-host",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--thinking-budget",
|
||||||
|
type=str,
|
||||||
|
choices=["low", "medium", "high"],
|
||||||
|
default=None,
|
||||||
|
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-base",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base URL for OpenAI-compatible APIs",
|
||||||
|
)
|
||||||
|
llm_group.add_argument(
|
||||||
|
"--llm-api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# AST Chunking parameters
|
||||||
|
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--use-ast-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Maximum characters per AST chunk (default: 512)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-chunk-overlap",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Overlap between AST chunks (default: 64)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--code-file-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||||
|
)
|
||||||
|
ast_group.add_argument(
|
||||||
|
"--ast-fallback-traditional",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search parameters
|
||||||
|
search_group = parser.add_argument_group("Search Parameters")
|
||||||
|
search_group.add_argument(
|
||||||
|
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||||
|
)
|
||||||
|
search_group.add_argument(
|
||||||
|
"--search-complexity",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Search complexity for graph traversal (default: 64)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index building parameters
|
||||||
|
index_group = parser.add_argument_group("Index Building Parameters")
|
||||||
|
index_group.add_argument(
|
||||||
|
"--backend-name",
|
||||||
|
type=str,
|
||||||
|
default="hnsw",
|
||||||
|
choices=["hnsw", "diskann"],
|
||||||
|
help="Backend to use for index (default: hnsw)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--graph-degree",
|
||||||
|
type=int,
|
||||||
|
default=32,
|
||||||
|
help="Graph degree for index construction (default: 32)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--build-complexity",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Build complexity for index construction (default: 64)",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-compact",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable compact index storage",
|
||||||
|
)
|
||||||
|
index_group.add_argument(
|
||||||
|
"--no-recompute",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable embedding recomputation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add source-specific parameters
|
||||||
|
self._add_specific_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add source-specific arguments. Override in subclasses."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load data from the source. Returns list of text chunks."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_llm_config(self, args) -> dict[str, Any]:
|
||||||
|
"""Get LLM configuration based on arguments."""
|
||||||
|
config = {"type": args.llm}
|
||||||
|
|
||||||
|
if args.llm == "openai":
|
||||||
|
config["model"] = args.llm_model or "gpt-4o"
|
||||||
|
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||||
|
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||||
|
if resolved_key:
|
||||||
|
config["api_key"] = resolved_key
|
||||||
|
elif args.llm == "ollama":
|
||||||
|
config["model"] = args.llm_model or "llama3.2:1b"
|
||||||
|
config["host"] = resolve_ollama_host(args.llm_host)
|
||||||
|
elif args.llm == "hf":
|
||||||
|
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
elif args.llm == "simulated":
|
||||||
|
# Simulated LLM doesn't need additional configuration
|
||||||
|
pass
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
async def build_index(self, args, texts: list[str]) -> str:
|
||||||
|
"""Build LEANN index from texts."""
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
|
||||||
|
print(f"\n[Building Index] Creating {self.name} index...")
|
||||||
|
print(f"Total text chunks: {len(texts)}")
|
||||||
|
|
||||||
|
embedding_options: dict[str, Any] = {}
|
||||||
|
if args.embedding_mode == "ollama":
|
||||||
|
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||||
|
elif args.embedding_mode == "openai":
|
||||||
|
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||||
|
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||||
|
if resolved_embedding_key:
|
||||||
|
embedding_options["api_key"] = resolved_embedding_key
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=args.backend_name,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
embedding_options=embedding_options or None,
|
||||||
|
graph_degree=args.graph_degree,
|
||||||
|
complexity=args.build_complexity,
|
||||||
|
is_compact=not args.no_compact,
|
||||||
|
is_recompute=not args.no_recompute,
|
||||||
|
num_threads=1, # Force single-threaded mode
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add texts in batches for better progress tracking
|
||||||
|
batch_size = 1000
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
batch = texts[i : i + batch_size]
|
||||||
|
for text in batch:
|
||||||
|
builder.add_text(text)
|
||||||
|
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||||
|
|
||||||
|
print("Building index structure...")
|
||||||
|
builder.build_index(index_path)
|
||||||
|
print(f"Index saved to: {index_path}")
|
||||||
|
|
||||||
|
# Register project directory so leann list can discover this index
|
||||||
|
# The index is saved as args.index_dir/index_name.leann
|
||||||
|
# We want to register the current working directory where the app is run
|
||||||
|
register_project_directory(Path.cwd())
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
async def run_interactive_chat(self, args, index_path: str):
|
||||||
|
"""Run interactive chat with the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Interactive Mode] Chat with your {self.name} data!")
|
||||||
|
print("Type 'quit' or 'exit' to stop.\n")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("You: ").strip()
|
||||||
|
if query.lower() in ["quit", "exit", "q"]:
|
||||||
|
print("Goodbye!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query,
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
)
|
||||||
|
print(f"\nAssistant: {response}\n")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
|
||||||
|
async def run_single_query(self, args, index_path: str, query: str):
|
||||||
|
"""Run a single query against the index."""
|
||||||
|
chat = LeannChat(
|
||||||
|
index_path,
|
||||||
|
llm_config=self.get_llm_config(args),
|
||||||
|
complexity=args.search_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||||
|
|
||||||
|
# Prepare LLM kwargs with thinking budget if specified
|
||||||
|
llm_kwargs = {}
|
||||||
|
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||||
|
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||||
|
|
||||||
|
response = chat.ask(
|
||||||
|
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||||
|
)
|
||||||
|
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Main entry point for the example."""
|
||||||
|
args = self.parser.parse_args()
|
||||||
|
|
||||||
|
# Check if index exists
|
||||||
|
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||||
|
index_exists = Path(args.index_dir).exists()
|
||||||
|
|
||||||
|
if not index_exists or args.force_rebuild:
|
||||||
|
# Load data and build index
|
||||||
|
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||||
|
texts = await self.load_data(args)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
print("No data found to index!")
|
||||||
|
return
|
||||||
|
|
||||||
|
index_path = await self.build_index(args, texts)
|
||||||
|
else:
|
||||||
|
print(f"\nUsing existing index in {args.index_dir}")
|
||||||
|
|
||||||
|
# Run query or interactive mode
|
||||||
|
if args.query:
|
||||||
|
await self.run_single_query(args, index_path, args.query)
|
||||||
|
else:
|
||||||
|
await self.run_interactive_chat(args, index_path)
|
||||||
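For context, a minimal sketch of how a concrete app might build on the `BaseRAGExample` class added above: it supplies the three constructor arguments, adds one source-specific flag, and implements `load_data`. The subclass name, the `--data-dir` flag, and the import path are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch for apps/base_rag_example.py (not part of the diff).
import asyncio
from pathlib import Path

from base_rag_example import BaseRAGExample  # assumes apps/ is on the import path


class TextFileRAG(BaseRAGExample):
    """Indexes plain-text files from a local directory."""

    def __init__(self):
        super().__init__(
            name="Text Files",
            description="RAG over local .txt files",
            default_index_name="textfile_index",
        )

    def _add_specific_arguments(self, parser):
        # Source-specific flag; embedding/LLM/search/index options come from the base class.
        parser.add_argument(
            "--data-dir", type=str, default="./data", help="Directory containing .txt files"
        )

    async def load_data(self, args) -> list[str]:
        # Return raw text chunks; the base class handles batching and index building.
        return [p.read_text(encoding="utf-8") for p in Path(args.data_dir).glob("*.txt")]


if __name__ == "__main__":
    asyncio.run(TextFileRAG().run())
```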
@@ -1,338 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
# Setup logging
|
|
||||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
|
||||||
"""Get current memory usage in MB"""
|
|
||||||
process = psutil.Process()
|
|
||||||
return process.memory_info().rss / 1024 / 1024
|
|
||||||
|
|
||||||
|
|
||||||
def print_memory_stats(stage: str, start_mem: float):
|
|
||||||
"""Print memory statistics"""
|
|
||||||
current_mem = get_memory_usage()
|
|
||||||
diff = current_mem - start_mem
|
|
||||||
print(f"[{stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
|
||||||
return current_mem
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryTracker:
|
|
||||||
def __init__(self, name: str):
|
|
||||||
self.name = name
|
|
||||||
self.start_mem = get_memory_usage()
|
|
||||||
self.stages = []
|
|
||||||
|
|
||||||
def checkpoint(self, stage: str):
|
|
||||||
current_mem = print_memory_stats(f"{self.name} - {stage}", self.start_mem)
|
|
||||||
self.stages.append((stage, current_mem))
|
|
||||||
return current_mem
|
|
||||||
|
|
||||||
def summary(self):
|
|
||||||
print(f"\n=== {self.name} Memory Summary ===")
|
|
||||||
for stage, mem in self.stages:
|
|
||||||
print(f"{stage}: {mem:.1f} MB")
|
|
||||||
peak_mem = max(mem for _, mem in self.stages)
|
|
||||||
print(f"Peak Memory: {peak_mem:.1f} MB")
|
|
||||||
print(f"Total Memory Increase: {peak_mem - self.start_mem:.1f} MB")
|
|
||||||
return peak_mem
|
|
||||||
|
|
||||||
|
|
||||||
def test_faiss_hnsw():
|
|
||||||
"""Test Faiss HNSW Vector Store in subprocess"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("TESTING FAISS HNSW VECTOR STORE")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get the directory of this script
|
|
||||||
script_dir = Path(__file__).parent
|
|
||||||
faiss_script = script_dir / "faiss_only.py"
|
|
||||||
result = subprocess.run(
|
|
||||||
[sys.executable, str(faiss_script)],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=300,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(result.stdout)
|
|
||||||
if result.stderr:
|
|
||||||
print("Stderr:", result.stderr)
|
|
||||||
|
|
||||||
if result.returncode != 0:
|
|
||||||
return {
|
|
||||||
"peak_memory": float("inf"),
|
|
||||||
"error": f"Process failed with code {result.returncode}",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Parse peak memory from output
|
|
||||||
lines = result.stdout.split("\n")
|
|
||||||
peak_memory = 0.0
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
if "Peak Memory:" in line:
|
|
||||||
peak_memory = float(
|
|
||||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"peak_memory": peak_memory}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"peak_memory": float("inf"),
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_leann_hnsw():
|
|
||||||
"""Test LEANN HNSW Search Memory (load existing index)"""
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("TESTING LEANN HNSW SEARCH MEMORY")
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
tracker = MemoryTracker("LEANN HNSW Search")
|
|
||||||
|
|
||||||
# Import and setup
|
|
||||||
tracker.checkpoint("Initial")
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher
|
|
||||||
|
|
||||||
tracker.checkpoint("After imports")
|
|
||||||
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
# Load and parse documents
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"../documents/data",
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data()
|
|
||||||
|
|
||||||
tracker.checkpoint("After document loading")
|
|
||||||
|
|
||||||
# Parse into chunks
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
|
||||||
|
|
||||||
# Build LEANN index
|
|
||||||
INDEX_DIR = Path("./test_leann_comparison")
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "comparison.leann")
|
|
||||||
|
|
||||||
# Check if index already exists
|
|
||||||
if os.path.exists(INDEX_PATH + ".meta.json"):
|
|
||||||
print("Loading existing LEANN HNSW index...")
|
|
||||||
tracker.checkpoint("After loading existing index")
|
|
||||||
else:
|
|
||||||
print("Building new LEANN HNSW index...")
|
|
||||||
# Clean up previous index
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
shutil.rmtree(INDEX_DIR)
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
tracker.checkpoint("After builder setup")
|
|
||||||
|
|
||||||
print("Building LEANN HNSW index...")
|
|
||||||
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
del builder
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
tracker.checkpoint("After index building")
|
|
||||||
|
|
||||||
# Find existing LEANN index
|
|
||||||
index_paths = [
|
|
||||||
"./test_leann_comparison/comparison.leann",
|
|
||||||
]
|
|
||||||
index_path = None
|
|
||||||
for path in index_paths:
|
|
||||||
if os.path.exists(path + ".meta.json"):
|
|
||||||
index_path = path
|
|
||||||
break
|
|
||||||
|
|
||||||
if not index_path:
|
|
||||||
print("❌ LEANN index not found. Please build it first")
|
|
||||||
return {"peak_memory": float("inf"), "error": "Index not found"}
|
|
||||||
|
|
||||||
# Measure runtime memory overhead
|
|
||||||
print("\nMeasuring runtime memory overhead...")
|
|
||||||
runtime_start_mem = get_memory_usage()
|
|
||||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
|
||||||
tracker.checkpoint("Before load memory")
|
|
||||||
|
|
||||||
# Load searcher
|
|
||||||
searcher = LeannSearcher(index_path)
|
|
||||||
tracker.checkpoint("After searcher loading")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Running search queries...")
|
|
||||||
queries = [
|
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
|
||||||
"What is LEANN and how does it work?",
|
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, query in enumerate(queries):
|
|
||||||
start_time = time.time()
|
|
||||||
# Use same parameters as Faiss: top_k=20, ef=120 (complexity parameter)
|
|
||||||
_ = searcher.search(query, top_k=20, ef=120)
|
|
||||||
query_time = time.time() - start_time
|
|
||||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
|
||||||
tracker.checkpoint(f"After query {i + 1}")
|
|
||||||
|
|
||||||
runtime_end_mem = get_memory_usage()
|
|
||||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
|
||||||
|
|
||||||
peak_memory = tracker.summary()
|
|
||||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
|
||||||
|
|
||||||
# Get storage size before cleanup
|
|
||||||
storage_size = 0
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
if INDEX_DIR.exists():
|
|
||||||
total_size = 0
|
|
||||||
for dirpath, _, filenames in os.walk(str(INDEX_DIR)):
|
|
||||||
for filename in filenames:
|
|
||||||
# Only count actual index files, skip text data and backups
|
|
||||||
if filename.endswith((".old", ".tmp", ".bak", ".jsonl", ".json")):
|
|
||||||
continue
|
|
||||||
# Count .index, .idx, .map files (actual index structures)
|
|
||||||
if filename.endswith((".index", ".idx", ".map")):
|
|
||||||
filepath = os.path.join(dirpath, filename)
|
|
||||||
total_size += os.path.getsize(filepath)
|
|
||||||
storage_size = total_size / (1024 * 1024) # Convert to MB
|
|
||||||
|
|
||||||
# Clean up
|
|
||||||
del searcher
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"peak_memory": peak_memory,
|
|
||||||
"storage_size": storage_size,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run comparison tests"""
|
|
||||||
print("Storage + Search Memory Comparison: Faiss HNSW vs LEANN HNSW")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Test Faiss HNSW
|
|
||||||
faiss_results = test_faiss_hnsw()
|
|
||||||
|
|
||||||
# Force garbage collection
|
|
||||||
gc.collect()
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
# Test LEANN HNSW
|
|
||||||
leann_results = test_leann_hnsw()
|
|
||||||
|
|
||||||
# Final comparison
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("STORAGE + SEARCH MEMORY COMPARISON")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Get storage sizes
|
|
||||||
faiss_storage_size = 0
|
|
||||||
leann_storage_size = leann_results.get("storage_size", 0)
|
|
||||||
|
|
||||||
# Get Faiss storage size using Python
|
|
||||||
if os.path.exists("./storage_faiss"):
|
|
||||||
total_size = 0
|
|
||||||
for dirpath, _, filenames in os.walk("./storage_faiss"):
|
|
||||||
for filename in filenames:
|
|
||||||
filepath = os.path.join(dirpath, filename)
|
|
||||||
total_size += os.path.getsize(filepath)
|
|
||||||
faiss_storage_size = total_size / (1024 * 1024) # Convert to MB
|
|
||||||
|
|
||||||
print("Faiss HNSW:")
|
|
||||||
if "error" in faiss_results:
|
|
||||||
print(f" ❌ Failed: {faiss_results['error']}")
|
|
||||||
else:
|
|
||||||
print(f" Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
|
||||||
print(f" Storage Size: {faiss_storage_size:.1f} MB")
|
|
||||||
|
|
||||||
print("\nLEANN HNSW:")
|
|
||||||
if "error" in leann_results:
|
|
||||||
print(f" ❌ Failed: {leann_results['error']}")
|
|
||||||
else:
|
|
||||||
print(f" Search Memory: {leann_results['peak_memory']:.1f} MB")
|
|
||||||
print(f" Storage Size: {leann_storage_size:.1f} MB")
|
|
||||||
|
|
||||||
# Calculate improvements only if both tests succeeded
|
|
||||||
if "error" not in faiss_results and "error" not in leann_results:
|
|
||||||
memory_ratio = faiss_results["peak_memory"] / leann_results["peak_memory"]
|
|
||||||
|
|
||||||
print("\nLEANN vs Faiss Performance:")
|
|
||||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
|
||||||
print(
|
|
||||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Storage comparison
|
|
||||||
if leann_storage_size > faiss_storage_size:
|
|
||||||
storage_ratio = leann_storage_size / faiss_storage_size
|
|
||||||
print(
|
|
||||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
|
||||||
)
|
|
||||||
elif faiss_storage_size > leann_storage_size:
|
|
||||||
storage_ratio = faiss_storage_size / leann_storage_size
|
|
||||||
print(
|
|
||||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(" Storage Size: similar")
|
|
||||||
else:
|
|
||||||
if "error" not in leann_results:
|
|
||||||
print("\n✅ LEANN HNSW completed successfully!")
|
|
||||||
print(f"📊 Search Memory: {leann_results['peak_memory']:.1f} MB")
|
|
||||||
print(f"📊 Storage Size: {leann_storage_size:.1f} MB")
|
|
||||||
if "error" not in faiss_results:
|
|
||||||
print("\n✅ Faiss HNSW completed successfully!")
|
|
||||||
print(f"📊 Search Memory: {faiss_results['peak_memory']:.1f} MB")
|
|
||||||
print(f"📊 Storage Size: {faiss_storage_size:.1f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,151 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""Test only Faiss HNSW"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_usage():
|
|
||||||
process = psutil.Process()
|
|
||||||
return process.memory_info().rss / 1024 / 1024
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryTracker:
|
|
||||||
def __init__(self, name: str):
|
|
||||||
self.name = name
|
|
||||||
self.start_mem = get_memory_usage()
|
|
||||||
self.stages = []
|
|
||||||
|
|
||||||
def checkpoint(self, stage: str):
|
|
||||||
current_mem = get_memory_usage()
|
|
||||||
diff = current_mem - self.start_mem
|
|
||||||
print(f"[{self.name} - {stage}] Memory: {current_mem:.1f} MB (+{diff:.1f} MB)")
|
|
||||||
self.stages.append((stage, current_mem))
|
|
||||||
return current_mem
|
|
||||||
|
|
||||||
def summary(self):
|
|
||||||
peak_mem = max(mem for _, mem in self.stages)
|
|
||||||
print(f"Peak Memory: {peak_mem:.1f} MB")
|
|
||||||
return peak_mem
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
try:
|
|
||||||
import faiss
|
|
||||||
except ImportError:
|
|
||||||
print("Faiss is not installed.")
|
|
||||||
print("Please install it with `uv pip install faiss-cpu`")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
)
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
|
||||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
||||||
|
|
||||||
tracker = MemoryTracker("Faiss HNSW")
|
|
||||||
tracker.checkpoint("Initial")
|
|
||||||
|
|
||||||
embed_model = HuggingFaceEmbedding(model_name="facebook/contriever")
|
|
||||||
Settings.embed_model = embed_model
|
|
||||||
tracker.checkpoint("After embedding model setup")
|
|
||||||
|
|
||||||
d = 768
|
|
||||||
faiss_index = faiss.IndexHNSWFlat(d, 32)
|
|
||||||
faiss_index.hnsw.efConstruction = 64
|
|
||||||
tracker.checkpoint("After Faiss index creation")
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
"../documents/data",
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data()
|
|
||||||
tracker.checkpoint("After document loading")
|
|
||||||
|
|
||||||
# Parse into chunks using the same splitter as LEANN
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=256, chunk_overlap=20, separator=" ", paragraph_separator="\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
tracker.checkpoint("After text splitter setup")
|
|
||||||
|
|
||||||
# Check if index already exists and try to load it
|
|
||||||
index_loaded = False
|
|
||||||
if os.path.exists("./storage_faiss"):
|
|
||||||
print("Loading existing Faiss HNSW index...")
|
|
||||||
try:
|
|
||||||
# Use the correct Faiss loading pattern from the example
|
|
||||||
vector_store = FaissVectorStore.from_persist_dir("./storage_faiss")
|
|
||||||
storage_context = StorageContext.from_defaults(
|
|
||||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
|
||||||
)
|
|
||||||
from llama_index.core import load_index_from_storage
|
|
||||||
index = load_index_from_storage(storage_context=storage_context)
|
|
||||||
print(f"Index loaded from ./storage_faiss")
|
|
||||||
tracker.checkpoint("After loading existing index")
|
|
||||||
index_loaded = True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to load existing index: {e}")
|
|
||||||
print("Cleaning up corrupted index and building new one...")
|
|
||||||
# Clean up corrupted index
|
|
||||||
import shutil
|
|
||||||
if os.path.exists("./storage_faiss"):
|
|
||||||
shutil.rmtree("./storage_faiss")
|
|
||||||
|
|
||||||
if not index_loaded:
|
|
||||||
print("Building new Faiss HNSW index...")
|
|
||||||
|
|
||||||
# Use the correct Faiss building pattern from the example
|
|
||||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
|
||||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
|
||||||
index = VectorStoreIndex.from_documents(
|
|
||||||
documents,
|
|
||||||
storage_context=storage_context,
|
|
||||||
transformations=[node_parser]
|
|
||||||
)
|
|
||||||
tracker.checkpoint("After index building")
|
|
||||||
|
|
||||||
# Save index to disk using the correct pattern
|
|
||||||
index.storage_context.persist(persist_dir="./storage_faiss")
|
|
||||||
tracker.checkpoint("After index saving")
|
|
||||||
|
|
||||||
# Measure runtime memory overhead
|
|
||||||
print("\nMeasuring runtime memory overhead...")
|
|
||||||
runtime_start_mem = get_memory_usage()
|
|
||||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
|
||||||
tracker.checkpoint("Before load memory")
|
|
||||||
|
|
||||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
|
||||||
queries = [
|
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
|
||||||
"What is LEANN and how does it work?",
|
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
|
||||||
]
|
|
||||||
|
|
||||||
for i, query in enumerate(queries):
|
|
||||||
start_time = time.time()
|
|
||||||
_ = query_engine.query(query)
|
|
||||||
query_time = time.time() - start_time
|
|
||||||
print(f"Query {i + 1} time: {query_time:.3f}s")
|
|
||||||
tracker.checkpoint(f"After query {i + 1}")
|
|
||||||
|
|
||||||
runtime_end_mem = get_memory_usage()
|
|
||||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
|
||||||
|
|
||||||
peak_memory = tracker.summary()
|
|
||||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
|
||||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import argparse
|
|
||||||
try:
|
|
||||||
import dotenv
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
# python-dotenv is not installed; skip loading environment variables
|
|
||||||
dotenv = None
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
|
|
||||||
# Default Chrome profile path
|
|
||||||
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
|
|
||||||
"""
|
|
||||||
Create LEANN index from multiple Chrome profile data sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
profile_dirs: List of Path objects pointing to Chrome profile directories
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of history entries to process per profile
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from multiple Chrome profile data sources...")
|
|
||||||
|
|
||||||
# Load documents using ChromeHistoryReader from local readers module
|
|
||||||
from .readers import ChromeHistoryReader
|
|
||||||
reader = ChromeHistoryReader()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
# Process each Chrome profile directory
|
|
||||||
for i, profile_dir in enumerate(profile_dirs):
|
|
||||||
print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
documents = reader.load_data(
|
|
||||||
chrome_profile_path=str(profile_dir),
|
|
||||||
max_count=max_count
|
|
||||||
)
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} history documents from {profile_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
|
|
||||||
# Check if we've reached the max count
|
|
||||||
if max_count > 0 and total_processed >= max_count:
|
|
||||||
print(f"Reached max count of {max_count} documents")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {profile_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {profile_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1 # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} history chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
|
||||||
"""
|
|
||||||
Query the LEANN index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: The query string
|
|
||||||
"""
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=index_path)
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=10,
|
|
||||||
recompute_neighbor_embeddings=True,
|
|
||||||
complexity=32,
|
|
||||||
beam_width=1,
|
|
||||||
llm_config={
|
|
||||||
"type": "openai",
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
llm_kwargs={
|
|
||||||
"temperature": 0.0,
|
|
||||||
"max_tokens": 1000
|
|
||||||
}
|
|
||||||
)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
|
|
||||||
parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
|
|
||||||
help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); usually you do not need to change this')
|
|
||||||
parser.add_argument('--index-dir', type=str, default="./chrome_history_index_leann_test",
|
|
||||||
help='Directory to store the LEANN index (default: ./chrome_history_index_leann_test)')
|
|
||||||
parser.add_argument('--max-entries', type=int, default=1000,
|
|
||||||
help='Maximum number of history entries to process (default: 1000)')
|
|
||||||
parser.add_argument('--query', type=str, default=None,
|
|
||||||
help='Single query to run (default: runs example queries)')
|
|
||||||
parser.add_argument('--auto-find-profiles', action='store_true', default=True,
|
|
||||||
help='Automatically find all Chrome profiles (default: True)')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")
|
|
||||||
|
|
||||||
print(f"Using Chrome profile: {args.chrome_profile}")
|
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
|
||||||
print(f"Max entries: {args.max_entries}")
|
|
||||||
|
|
||||||
# Find Chrome profile directories
|
|
||||||
from .readers import ChromeHistoryReader
|
|
||||||
|
|
||||||
if args.auto_find_profiles:
|
|
||||||
profile_dirs = ChromeHistoryReader.find_chrome_profiles()
|
|
||||||
if not profile_dirs:
|
|
||||||
print("No Chrome profiles found automatically. Exiting.")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# Use single specified profile
|
|
||||||
profile_path = Path(args.chrome_profile)
|
|
||||||
if not profile_path.exists():
|
|
||||||
print(f"Chrome profile not found: {profile_path}")
|
|
||||||
return
|
|
||||||
profile_dirs = [profile_path]
|
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
|
||||||
index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)
|
|
||||||
|
|
||||||
if index_path:
|
|
||||||
if args.query:
|
|
||||||
# Run single query
|
|
||||||
await query_leann_index(index_path, args.query)
|
|
||||||
else:
|
|
||||||
# Example queries
|
|
||||||
queries = [
|
|
||||||
"What websites did I visit about machine learning?",
|
|
||||||
"Find my search history about programming"
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "="*60)
|
|
||||||
await query_leann_index(index_path, query)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,176 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
|
|
||||||
class ChromeHistoryReader(BaseReader):
|
|
||||||
"""
|
|
||||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
|
||||||
|
|
||||||
Reads Chrome history from the default Chrome profile location and creates documents
|
|
||||||
with embedded metadata similar to the email reader structure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
|
||||||
"""
|
|
||||||
Load Chrome history data from the default Chrome profile location.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Not used for Chrome history (kept for compatibility)
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum amount of history entries to read.
|
|
||||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
|
||||||
"""
|
|
||||||
docs: List[Document] = []
|
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
|
||||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
|
||||||
|
|
||||||
# Default Chrome profile path on macOS
|
|
||||||
if chrome_profile_path is None:
|
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
|
||||||
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
|
||||||
print(f"Chrome history database not found at: {history_db_path}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Connect to the Chrome history database
|
|
||||||
print(f"Connecting to database: {history_db_path}")
|
|
||||||
conn = sqlite3.connect(history_db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Query to get browsing history with metadata (removed created_time column)
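# Chrome stores last_visit_time as microseconds since 1601-01-01 (the WebKit
# epoch); dividing by 1,000,000 and subtracting 11644473600 seconds converts
# it to a Unix timestamp rendered in local time.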
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
|
||||||
url,
|
|
||||||
title,
|
|
||||||
visit_count,
|
|
||||||
typed_count,
|
|
||||||
hidden
|
|
||||||
FROM urls
|
|
||||||
ORDER BY last_visit_time DESC
|
|
||||||
"""
|
|
||||||
|
|
||||||
print(f"Executing query on database: {history_db_path}")
|
|
||||||
cursor.execute(query)
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
print(f"Query returned {len(rows)} rows")
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for row in rows:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
|
||||||
|
|
||||||
# Create document content with metadata embedded in text
|
|
||||||
doc_content = f"""
|
|
||||||
[BROWSING HISTORY METADATA]
|
|
||||||
URL: {url}
|
|
||||||
Title: {title}
|
|
||||||
Last Visit: {last_visit}
|
|
||||||
Visit Count: {visit_count}
|
|
||||||
Typed Count: {typed_count}
|
|
||||||
Hidden: {hidden}
|
|
||||||
[END METADATA]
|
|
||||||
|
|
||||||
Title: {title}
|
|
||||||
URL: {url}
|
|
||||||
Last visited: {last_visit}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create document with embedded metadata
|
|
||||||
doc = Document(text=doc_content, metadata={})
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
print(f"Loaded {len(docs)} Chrome history documents")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading Chrome history: {e}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
return docs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def find_chrome_profiles() -> List[Path]:
|
|
||||||
"""
|
|
||||||
Find all Chrome profile directories.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Path objects pointing to Chrome profile directories
|
|
||||||
"""
|
|
||||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
|
||||||
profile_dirs = []
|
|
||||||
|
|
||||||
if not chrome_base_path.exists():
|
|
||||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
|
||||||
return profile_dirs
|
|
||||||
|
|
||||||
# Find all profile directories
|
|
||||||
for profile_dir in chrome_base_path.iterdir():
|
|
||||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
|
||||||
history_path = profile_dir / "History"
|
|
||||||
if history_path.exists():
|
|
||||||
profile_dirs.append(profile_dir)
|
|
||||||
print(f"Found Chrome profile: {profile_dir}")
|
|
||||||
|
|
||||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
|
||||||
return profile_dirs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
|
||||||
"""
|
|
||||||
Export Chrome history to a text file using the same SQL query format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_file: Path to the output file
|
|
||||||
max_count: Maximum number of entries to export
|
|
||||||
"""
|
|
||||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
|
||||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
|
||||||
|
|
||||||
if not os.path.exists(history_db_path):
|
|
||||||
print(f"Chrome history database not found at: {history_db_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn = sqlite3.connect(history_db_path)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
query = """
|
|
||||||
SELECT
|
|
||||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
|
||||||
url,
|
|
||||||
title,
|
|
||||||
visit_count,
|
|
||||||
typed_count,
|
|
||||||
hidden
|
|
||||||
FROM urls
|
|
||||||
ORDER BY last_visit_time DESC
|
|
||||||
LIMIT ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
cursor.execute(query, (max_count,))
|
|
||||||
rows = cursor.fetchall()
|
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
for row in rows:
|
|
||||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
|
||||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error exporting Chrome history: {e}")
|
|
||||||
apps/browser_rag.py (new file, 171 lines)
@@ -0,0 +1,171 @@
|
|||||||
|
"""
|
||||||
|
Browser History RAG example using the unified interface.
|
||||||
|
Supports Chrome browser history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
|
from .history_data.history import ChromeHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class BrowserRAG(BaseRAGExample):
|
||||||
|
"""RAG example for Chrome browser history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="Browser History",
|
||||||
|
description="Process and query Chrome browser history with LEANN",
|
||||||
|
default_index_name="google_history_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add browser-specific arguments."""
|
||||||
|
browser_group = parser.add_argument_group("Browser Parameters")
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chrome-profile",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--auto-find-profiles",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Automatically find all Chrome profiles (default: True)",
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
browser_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_chrome_base_path(self) -> Path:
|
||||||
|
"""Get the base Chrome profile path based on OS."""
|
||||||
|
if sys.platform == "darwin":
|
||||||
|
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||||
|
elif sys.platform.startswith("linux"):
|
||||||
|
return Path.home() / ".config" / "google-chrome"
|
||||||
|
elif sys.platform == "win32":
|
||||||
|
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||||
|
|
||||||
|
def _find_chrome_profiles(self) -> list[Path]:
|
||||||
|
"""Auto-detect all Chrome profiles."""
|
||||||
|
base_path = self._get_chrome_base_path()
|
||||||
|
if not base_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
profiles = []
|
||||||
|
|
||||||
|
# Check Default profile
|
||||||
|
default_profile = base_path / "Default"
|
||||||
|
if default_profile.exists() and (default_profile / "History").exists():
|
||||||
|
profiles.append(default_profile)
|
||||||
|
|
||||||
|
# Check numbered profiles
|
||||||
|
for item in base_path.iterdir():
|
||||||
|
if item.is_dir() and item.name.startswith("Profile "):
|
||||||
|
if (item / "History").exists():
|
||||||
|
profiles.append(item)
|
||||||
|
|
||||||
|
return profiles
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load browser history and convert to text chunks."""
|
||||||
|
# Determine Chrome profiles
|
||||||
|
if args.chrome_profile and not args.auto_find_profiles:
|
||||||
|
profile_dirs = [Path(args.chrome_profile)]
|
||||||
|
else:
|
||||||
|
print("Auto-detecting Chrome profiles...")
|
||||||
|
profile_dirs = self._find_chrome_profiles()
|
||||||
|
|
||||||
|
# If specific profile given, filter to just that one
|
||||||
|
if args.chrome_profile:
|
||||||
|
profile_path = Path(args.chrome_profile)
|
||||||
|
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||||
|
|
||||||
|
if not profile_dirs:
|
||||||
|
print("No Chrome profiles found!")
|
||||||
|
print("Please specify --chrome-profile manually")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||||
|
|
||||||
|
# Create reader
|
||||||
|
reader = ChromeHistoryReader()
|
||||||
|
|
||||||
|
# Process each profile
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, profile_dir in enumerate(profile_dirs):
|
||||||
|
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per profile
|
||||||
|
max_per_profile = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_profile = remaining
|
||||||
|
|
||||||
|
# Load history
|
||||||
|
documents = reader.load_data(
|
||||||
|
chrome_profile_path=str(profile_dir),
|
||||||
|
max_count=max_per_profile,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} history entries from this profile")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {profile_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No browser history found to process!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for browser history RAG
|
||||||
|
print("\n🌐 Browser History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What websites did I visit about machine learning?'")
|
||||||
|
print("- 'Find my search history about programming'")
|
||||||
|
print("- 'What YouTube videos did I watch recently?'")
|
||||||
|
print("- 'Show me websites about travel planning'")
|
||||||
|
print("\nNote: Make sure Chrome is closed before running\n")
|
||||||
|
|
||||||
|
rag = BrowserRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
apps/chunking/__init__.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
|||||||
|
"""Unified chunking utilities facade.
|
||||||
|
|
||||||
|
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||||
|
that both repo apps (importing `chunking`) and installed wheels share one
|
||||||
|
single implementation. When running from the repo without installation, it
|
||||||
|
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||||
|
repo_root = Path(__file__).resolve().parents[2]
|
||||||
|
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||||
|
if leann_src.exists():
|
||||||
|
sys.path.insert(0, str(leann_src))
|
||||||
|
from leann.chunking_utils import (
|
||||||
|
CODE_EXTENSIONS,
|
||||||
|
create_ast_chunks,
|
||||||
|
create_text_chunks,
|
||||||
|
create_traditional_chunks,
|
||||||
|
detect_code_files,
|
||||||
|
get_language_from_extension,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"create_ast_chunks",
|
||||||
|
"create_text_chunks",
|
||||||
|
"create_traditional_chunks",
|
||||||
|
"detect_code_files",
|
||||||
|
"get_language_from_extension",
|
||||||
|
]
|
||||||
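A minimal usage sketch of the chunking facade above. The file name is hypothetical and it assumes it is run from the apps/ directory (or with apps/ on sys.path, as the other apps do); the call signature of create_text_chunks matches its use in document_rag.py and code_rag.py.

# sketch: apps/chunking_demo.py (hypothetical)
from llama_index.core import Document

from chunking import create_text_chunks  # resolves to leann.chunking_utils via the facade

docs = [Document(text="LEANN keeps storage small by recomputing embeddings at search time.")]
chunks = create_text_chunks(docs, chunk_size=256, chunk_overlap=64)
print(f"Produced {len(chunks)} chunks")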
apps/code_rag.py (new file, 211 lines)
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||||
|
Specialized for code repositories with automatic language detection and
|
||||||
|
optimized chunking parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRAG(BaseRAGExample):
|
||||||
|
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Code",
|
||||||
|
description="Process and query code repositories with AST-aware chunking",
|
||||||
|
default_index_name="code_index",
|
||||||
|
)
|
||||||
|
# Override defaults for code-specific usage
|
||||||
|
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||||
|
self.max_items_default = -1 # Process all code files by default
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add code-specific arguments."""
|
||||||
|
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||||
|
|
||||||
|
code_group.add_argument(
|
||||||
|
"--repo-dir",
|
||||||
|
type=str,
|
||||||
|
default=".",
|
||||||
|
help="Code repository directory to index (default: current directory)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-extensions",
|
||||||
|
nargs="+",
|
||||||
|
default=list(CODE_EXTENSIONS.keys()),
|
||||||
|
help="File extensions to include (default: supported code extensions)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--exclude-dirs",
|
||||||
|
nargs="+",
|
||||||
|
default=[
|
||||||
|
".git",
|
||||||
|
"__pycache__",
|
||||||
|
"node_modules",
|
||||||
|
"venv",
|
||||||
|
".venv",
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
],
|
||||||
|
help="Directories to exclude from indexing",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--max-file-size",
|
||||||
|
type=int,
|
||||||
|
default=1000000, # 1MB
|
||||||
|
help="Maximum file size in bytes to process (default: 1MB)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--include-comments",
|
||||||
|
action="store_true",
|
||||||
|
help="Include comments in chunking (useful for documentation)",
|
||||||
|
)
|
||||||
|
code_group.add_argument(
|
||||||
|
"--preserve-imports",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Try to preserve import statements in chunks (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load code files and convert to AST-aware chunks."""
|
||||||
|
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||||
|
print(f"📁 Including extensions: {args.include_extensions}")
|
||||||
|
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||||
|
|
||||||
|
# Check if repository directory exists
|
||||||
|
repo_path = Path(args.repo_dir)
|
||||||
|
if not repo_path.exists():
|
||||||
|
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||||
|
|
||||||
|
# Load code files with filtering
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
"required_exts": args.include_extensions,
|
||||||
|
"exclude_hidden": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create exclusion filter
|
||||||
|
def file_filter(file_path: str) -> bool:
|
||||||
|
"""Filter out unwanted files and directories."""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check file size
|
||||||
|
try:
|
||||||
|
if path.stat().st_size > args.max_file_size:
|
||||||
|
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if in excluded directory
|
||||||
|
for exclude_dir in args.exclude_dirs:
|
||||||
|
if exclude_dir in path.parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load documents with file filtering
|
||||||
|
documents = SimpleDirectoryReader(
|
||||||
|
args.repo_dir,
|
||||||
|
file_extractor=None, # Use default extractors
|
||||||
|
**reader_kwargs,
|
||||||
|
).load_data(show_progress=True)
|
||||||
|
|
||||||
|
# Apply custom filtering
|
||||||
|
filtered_docs = []
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_filter(file_path):
|
||||||
|
filtered_docs.append(doc)
|
||||||
|
|
||||||
|
documents = filtered_docs
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error loading code files: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(
|
||||||
|
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"✅ Loaded {len(documents)} code files")
|
||||||
|
|
||||||
|
# Show breakdown by language/extension
|
||||||
|
ext_counts = {}
|
||||||
|
for doc in documents:
|
||||||
|
file_path = doc.metadata.get("file_path", "")
|
||||||
|
if file_path:
|
||||||
|
ext = Path(file_path).suffix.lower()
|
||||||
|
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||||
|
|
||||||
|
print("📊 Files by extension:")
|
||||||
|
for ext, count in sorted(ext_counts.items()):
|
||||||
|
print(f" {ext}: {count} files")
|
||||||
|
|
||||||
|
# Use AST-aware chunking by default for code
|
||||||
|
print(
|
||||||
|
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=256, # Fallback for non-code files
|
||||||
|
chunk_overlap=64,
|
||||||
|
use_ast_chunking=True, # Always use AST for code RAG
|
||||||
|
ast_chunk_size=args.ast_chunk_size,
|
||||||
|
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||||
|
code_file_extensions=args.include_extensions,
|
||||||
|
ast_fallback_traditional=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for code RAG
|
||||||
|
print("\n💻 Code RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'How does the embedding computation work?'")
|
||||||
|
print("- 'What are the main classes in this codebase?'")
|
||||||
|
print("- 'Show me the search implementation'")
|
||||||
|
print("- 'How is error handling implemented?'")
|
||||||
|
print("- 'What design patterns are used?'")
|
||||||
|
print("- 'Explain the chunking logic'")
|
||||||
|
print("\n🚀 Features:")
|
||||||
|
print("- ✅ AST-aware chunking preserves code structure")
|
||||||
|
print("- ✅ Automatic language detection")
|
||||||
|
print("- ✅ Smart filtering of large files and common excludes")
|
||||||
|
print("- ✅ Optimized for code understanding")
|
||||||
|
print("\nUsage examples:")
|
||||||
|
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||||
|
print(
|
||||||
|
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||||
|
)
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = CodeRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
apps/document_rag.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
Document RAG example using the unified interface.
|
||||||
|
Supports PDF, TXT, MD, and other document formats.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentRAG(BaseRAGExample):
|
||||||
|
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Document",
|
||||||
|
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||||
|
default_index_name="test_doc_files",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add document-specific arguments."""
|
||||||
|
doc_group = parser.add_argument_group("Document Parameters")
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
type=str,
|
||||||
|
default="data",
|
||||||
|
help="Directory containing documents to index (default: data)",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--file-types",
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
doc_group.add_argument(
|
||||||
|
"--enable-code-chunking",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable AST-aware chunking for code files in the data directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load documents and convert to text chunks."""
|
||||||
|
print(f"Loading documents from: {args.data_dir}")
|
||||||
|
if args.file_types:
|
||||||
|
print(f"Filtering by file types: {args.file_types}")
|
||||||
|
else:
|
||||||
|
print("Processing all supported file types")
|
||||||
|
|
||||||
|
# Check if data directory exists
|
||||||
|
data_path = Path(args.data_dir)
|
||||||
|
if not data_path.exists():
|
||||||
|
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||||
|
|
||||||
|
# Load documents
|
||||||
|
reader_kwargs = {
|
||||||
|
"recursive": True,
|
||||||
|
"encoding": "utf-8",
|
||||||
|
}
|
||||||
|
if args.file_types:
|
||||||
|
reader_kwargs["required_exts"] = args.file_types
|
||||||
|
|
||||||
|
documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Loaded {len(documents)} documents")
|
||||||
|
|
||||||
|
# Determine chunking strategy
|
||||||
|
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||||
|
|
||||||
|
if use_ast:
|
||||||
|
print("Using AST-aware chunking for code files")
|
||||||
|
|
||||||
|
# Convert to text chunks with optional AST support
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=args.chunk_size,
|
||||||
|
chunk_overlap=args.chunk_overlap,
|
||||||
|
use_ast_chunking=use_ast,
|
||||||
|
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||||
|
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||||
|
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||||
|
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_items limit if specified
|
||||||
|
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||||
|
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||||
|
all_texts = all_texts[: args.max_items]
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for document RAG
|
||||||
|
print("\n📄 Document RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What are the main techniques LEANN uses?'")
|
||||||
|
print("- 'What is the technique DLPM?'")
|
||||||
|
print("- 'Who does Elizabeth Bennet marry?'")
|
||||||
|
print(
|
||||||
|
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||||
|
)
|
||||||
|
print("\n🚀 NEW: Code-aware chunking available!")
|
||||||
|
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||||
|
print("- Supports Python, Java, C#, TypeScript files")
|
||||||
|
print("- Better semantic understanding of code structure")
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = DocumentRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
import argparse
|
|
||||||
from llama_index.core import SimpleDirectoryReader
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
from leann.api import LeannBuilder, LeannChat
|
|
||||||
from pathlib import Path
|
|
||||||
import os
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
async def main(args):
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "pdf_documents.leann")
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
node_parser = SentenceSplitter(
|
|
||||||
chunk_size=256, chunk_overlap=128, separator=" ", paragraph_separator="\n\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Loading documents...")
|
|
||||||
# Get the data directory relative to this module
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
data_dir = current_dir / "data"
|
|
||||||
|
|
||||||
documents = SimpleDirectoryReader(
|
|
||||||
str(data_dir),
|
|
||||||
recursive=True,
|
|
||||||
encoding="utf-8",
|
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
|
||||||
).load_data(show_progress=True)
|
|
||||||
print("Documents loaded.")
|
|
||||||
all_texts = []
|
|
||||||
for doc in documents:
|
|
||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
all_texts.append(node.get_content())
|
|
||||||
|
|
||||||
print("--- Index directory not found, building new index ---")
|
|
||||||
|
|
||||||
print("\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="facebook/contriever",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Loaded {len(all_texts)} text chunks from documents.")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(INDEX_PATH)
|
|
||||||
print(f"\nLeann index built at {INDEX_PATH}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
|
|
||||||
# llm_config = {"type": "hf", "model": "Qwen/Qwen3-4B"}
|
|
||||||
llm_config = {"type": "ollama", "model": "qwen3:8b"}
|
|
||||||
|
|
||||||
chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)
|
|
||||||
|
|
||||||
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead and DLPM explore to achieve Fairness and Efiiciency trade-off?"
|
|
||||||
|
|
||||||
# query = (
|
|
||||||
# "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发"
|
|
||||||
# )
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(query, top_k=20, recompute_embeddings=True, complexity=32)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run Leann Chat with various LLM backends."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--llm",
|
|
||||||
type=str,
|
|
||||||
default="hf",
|
|
||||||
choices=["simulated", "ollama", "hf", "openai"],
|
|
||||||
help="The LLM backend to use.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
type=str,
|
|
||||||
default="Qwen/Qwen3-0.6B",
|
|
||||||
help="The model name to use (e.g., 'llama3:8b' for ollama, 'deepseek-ai/deepseek-llm-7b-chat' for hf, 'gpt-4o' for openai).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--host",
|
|
||||||
type=str,
|
|
||||||
default="http://localhost:11434",
|
|
||||||
help="The host for the Ollama API.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default="./test_doc_files",
|
|
||||||
help="Directory where the Leann index will be stored.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
asyncio.run(main(args))
|
|
||||||
@@ -1,82 +0,0 @@
# The Sorrow of Pangu: The Bitterness and Darkness Behind Huawei Noah's Ark Lab's Pangu LLM

Hello everyone,

I am an employee of the Pangu large-model team at Huawei's Noah's Ark Lab.

First, to establish that I am who I say I am, some details:

1. The current director of Noah's Ark is Wang Yunhe, formerly head of the Algorithm Application Department, which was later renamed the Small Model Lab. The previous director was Yao Jun (everyone calls him Teacher Yao). Lab directors include Tang Ruiming (Brother Ming, Captain Ming, since departed), Shang Lifeng, Zhang Wei (Brother Wei), Hao Jianye (Teacher Hao), Liu Wulong (addressed as Director Wulong), and others. Many other core members and experts have since left one after another.

2. We belong to the organization known as the "Fourth Field Army," which has many columns; the foundation language model is the Fourth Column, and Wang Yunhe's small models are the Sixteenth Column. We took part in the Suzhou rallies, with milestone deadlines in various months. At the Suzhou sprints, task orders were issued and targets had to be met before each deadline. The rallies concentrated people from all over at the Suzhou Research Institute; we normally stayed in hotels, for example in Luzhi, far from our families and children.

3. During the Suzhou rallies, working on Saturday was the default, which was exhausting, although Saturdays had afternoon tea and once even crayfish. Our workstations at the Suzhou Research Institute were moved once, from one building to another. The buildings have European-style decor, with a big ramp at the entrance and quite nice scenery inside. A rally trip usually lasted at least a week, often longer; some people could not go home for a month or two.

4. Noah's Ark was once said to be research-oriented, but after joining, the large-model project under the Fourth Field Army turned everyone into pure delivery staff, with endless regular meetings, reviews, and reporting. Often we even had to apply for permission to run experiments. The team had to support many business lines, including the device-side Xiaoyi assistant, Huawei Cloud, and ICT, so delivery pressure was heavy.

5. The Pangu model developed by Noah's Ark was internally codenamed "Pangu Zhizi" early on. At first there was only an internal web version that required an application to try; later, under pressure, it was made available on WeLink and opened for public testing.

The recent controversy over allegations that the Pangu model plagiarized Qwen has been making waves. As a member of the Pangu team, I have been tossing and turning at night, unable to sleep. With the Pangu brand so badly damaged, I selfishly worry about my own career and feel my past hard work was for nothing; on the other hand, now that people have started exposing these things, I feel vindicated. Over countless days and nights we gritted our teeth, powerless, while certain people inside repeatedly reaped enormous benefits through fraud. That suppression and humiliation gradually wore down my feelings for Huawei, leaving me muddling through my days here, lost, and often doubting my life and my own worth.

I admit I am a coward. As an ordinary employee, I dare not stand against well-connected people like Wang Yunhe, let alone against a behemoth like Huawei. I am afraid of losing my job; I too have a family and children, so from the bottom of my heart I admire the whistleblowers. But seeing the company still trying to whitewash the facts and deceive the public, I can no longer tolerate it. I want to be brave for once and follow my conscience. Even if it costs me dearly, I hope to land a blow. I have decided to make public what I have seen and heard here (partly from colleagues' accounts) about the "legend" of the Pangu models:

Huawei does indeed train its large models mainly on Ascend cards (the Small Model Lab has quite a few NVIDIA cards, which they used for training before moving to Ascend). I was once won over by Huawei's determination to "build the world's second choice," and I used to have deep feelings for the company. We crawled through the mud with Ascend step by step, from a platform full of bugs to one that can now train models, at enormous cost and effort.

At first our compute was very limited and we trained on the 910A, which only supported fp16, far less stable than bf16. Pangu's MoE work started early: in 2023 the main efforts were a 38B MoE model and a subsequent 71B dense model. The 71B dense model was upscaled into the first-generation 135B dense model, and the main models gradually moved onto 910B training.

Both the 71B and the 135B had one huge flaw: the tokenizer. Its encoding efficiency was extremely poor; every symbol, digit, space, and even every Chinese character took one token. Naturally this wasted compute and hurt model quality. At that time the Small Model Lab happened to have a vocabulary of its own. Teacher Yao suspected the tokenizer was the problem (in hindsight his suspicion was undoubtedly right), so he decided that the 71B and 135B should switch tokenizers, since the Small Model Lab had tried it before. The team stitched two tokenizers together and began the swap. The 71B swap failed; the 135B, thanks to a more careful embedding initialization strategy, finally succeeded after continued training on at least 1T tokens, but as you would expect, the results did not improve.

During the same period, other domestic companies such as Alibaba and Zhipu were training on GPUs and had already worked out the right recipes, so the gap between Pangu and its competitors kept widening. An internal 230B dense model trained from scratch also failed for various reasons, pushing the project close to a dead end. Facing several deadlines and intense internal doubts about Pangu, team morale hit rock bottom. With extremely limited compute, the team made many efforts and struggles. For example, they found by chance that the 38B MoE was not delivering the expected MoE gains, so they removed the MoE parameters and reduced it back to a 13B dense model. Since the 38B MoE descended from the very early PanGu-alpha 13B with an outdated architecture, the team made a series of changes: switching from absolute position encoding to RoPE, removing biases, and switching to RMSNorm. Given the earlier tokenizer failures and the vocabulary-swap experience, this model's vocabulary was also changed to the one used by the 7B model from Wang Yunhe's Small Model Lab. This 13B was later upscaled and further trained into the second-generation 38B dense model (for several months the main mid-tier Pangu model), which was once somewhat competitive. But the larger 135B, with its dated architecture and heavy damage from the vocabulary swap (later analysis found the stitched vocabulary had an even more serious bug), still lagged far behind leading domestic models like Qwen even after continued training. Internal doubts and pressure from leadership kept mounting, and the team was again close to despair.

At this point Wang Yunhe and his Small Model Lab stepped in. They claimed to have inherited and reworked the old 135B parameters, and after training on just a few hundred billion tokens, the metrics improved by about ten points on average. In reality, this was their first masterpiece of shell-wrapping applied to large models. At Huawei, laymen lead the experts, so leadership had no concept of how absurd this was; they simply assumed there must be some algorithmic innovation. Internal analysis showed it was actually continued training on Qwen 1.5 110B, with layers added, the FFN dimensions widened, and some mechanisms from the Pangu-pi paper bolted on to reach roughly 135B parameters. The old 135B had 107 layers; this model had only 82, and the various configurations were all different. After training, many parameter distributions of this new 135B of unknown origin were nearly identical to Qwen 110B's. Even the class name in the model code was still Qwen at the time; they could not be bothered to rename it. This model became the so-called 135B V2 and was shipped to many downstream teams, even external customers.

This was a huge shock to those of us doing honest, serious work; many people inside knew about it, including the device side and Huawei Cloud. We joked that it should no longer be called the Pangu model but "Qiangu" (a blend of Qianwen and Pangu). Team members wanted to report it to BCG at the time, since it was serious business fraud, but we heard the report was blocked by leadership, because higher-level leaders (such as Teacher Yao, and possibly President Xiong and Elder Zha) also learned of it later and did nothing: good results obtained through shell-wrapping benefited them too. This drove several of the strongest colleagues to lose heart, and talk of quitting became routine.

Then Pangu seemed to get a turning point. Because the Pangu models described above had basically all come from continued training and retrofits, Noah's Ark had not mastered training from scratch at all, let alone on Ascend NPUs. Thanks to the hard push of the team's core members, Pangu started training its third-generation models; after enormous effort, the data architecture and training algorithms gradually caught up with the industry, and none of that hardship had anything to do with the Small Model Lab.

At first the team had no confidence and started with only a 13B model, but the results turned out decent, so it was later upscaled again into the third-generation 38B, codenamed 38B V3. Many product-line colleagues will be familiar with it. Its tokenizer was extended from the LLaMA vocabulary (a common industry practice). Meanwhile Wang Yunhe's lab had produced another vocabulary (the one used by the later Pangu series). The two vocabularies were forced into a bake-off, which ended with no clear winner. Leadership then decided on the spot that the vocabulary should be unified, using Wang Yunhe's. So the 135B V3 trained from scratch afterwards (the externally announced Pangu Ultra) used that tokenizer. This also answers a question many users of our models had: why two different tiers of the same V3 generation used different tokenizers.

From the bottom of our hearts, 135B V3 was the pride of our Fourth Column team. It was the first truly full-stack, in-house Huawei model at the hundred-billion-parameter scale trained properly from scratch, with quality comparable to competitors of the same period in 2024. Writing this, my eyes are already wet; it was so hard. To stabilize training, the team ran a huge number of comparison experiments and promptly rolled back and restarted whenever gradients went abnormal. The model genuinely achieved what the technical report later claimed: not a single loss spike across the entire run. We overcame countless difficulties; we did it, and we would stake our lives and honor on the authenticity of that training. How many early mornings did we stay up for it. When we were being trashed on the internal Xinsheng forum, how unwilling and aggrieved we felt, and still we held on.

We really were burning our youth to polish a domestic compute foundation. Living away from home, we gave up family, holidays, health, and leisure; a few lines cannot capture a fraction of the hardship. At the mobilization rallies, when the slogans "Pangu must win, Huawei must win" were shouted, we were genuinely, deeply moved.

Yet our hard-won results were routinely, effortlessly taken by the Small Model Lab. Data: just handed over. Code: handed over too, and we were required to adapt it until it ran with one click. We jokingly called them the "mouse-clicking lab." We did the work; they took the glory. As the saying goes, someone walks in quiet comfort because someone else is carrying the load. More and more comrades could not hold on and chose to leave. Watching those excellent colleagues resign one by one, I felt both moved and sad. In that battle-like environment we were more like comrades-in-arms than coworkers, and technically they were mentors to me in countless ways. Seeing them join outstanding teams such as ByteDance Seed, DeepSeek, Moonshot, Tencent, and Kuaishou, I am happy for them and wish them well, free of this exhausting yet dirty place. I still remember what one departing colleague said: "Coming here was the disgrace of my technical career; every additional day here is a waste of life." Harsh words, but I had no reply. My worry that my own technical foundation is insufficient, and that I cannot survive the high attrition of internet companies, is what has kept me from taking that step despite wanting to quit many times.

Besides dense models, Pangu later started exploring MoE again, beginning with a 224B MoE model. In parallel, the Small Model Lab launched its second major shell-wrapping operation (minor episodes may include other models, such as a math model): the now widely discussed Pangu Pro MoE 72B. Internally this model is claimed to have been upscaled from the lab's 7B (even if true, that contradicts the technical report, and in fact it was continued training on a shell of Qwen 2.5 14B). I still remember that after only a few days of training, its internal evaluations immediately caught up with the then 38B V3. Many colleagues in the AI systems lab knew about the shell-wrapping because they had to adapt the model, but for various reasons they could not speak up. Frankly, for a model that was then further trained for a very long time, I am surprised HonestAGI could still detect that level of similarity, because the compute spent on laundering the parameters through continued training would long ago have been enough to train a model of that size from scratch. I heard from colleagues that they tried many things to scrub out Qwen's watermark, even deliberately training on dirty data. This will give academia an unprecedented case study in model lineage; future lineage-detection methods can use it as a benchmark.

In late 2024 and early 2025, after DeepSeek V3 and R1 were released, their stunning quality hit the team hard and brought even more doubt. To keep up with the trend, Pangu copied DeepSeek's model size and started training a 718B MoE. Then the Small Model Lab struck again. They chose to shell-wrap DeepSeek V3 and continue training, freezing the parameters loaded from DeepSeek. Even the checkpoint-loading directory for the job was "deepseekv3"; they did not bother to change it. How brazen. In contrast, colleagues with genuine technical conviction were training another 718B MoE from scratch, which ran into all kinds of problems; but obviously, how could it beat a straight shell? If the team lead had not insisted, it would have been shut down long ago.

Huawei's heavy process management severely drags down the pace of large-model development: version control, model lineage, endless procedures and traceability. The irony is that the Small Model Lab's models never seem bound by any of it: shell-wrap at will, continue training at will, with compute handed to them on demand. The contrast is so stark it feels surreal, and it sums up the state of process governance here: the officials may set fires while the common people may not light lamps. How laughable, how lamentable, how despicable, how shameful.

After the HonestAGI matter surfaced, we were made to hold endless internal discussions on how to do PR and "respond." Admittedly the original analysis may not have been airtight, which gave Wang Yunhe and the Small Model Lab room to quibble and invert black and white. These past two days it has made me sick and made me doubt the meaning of my life and whether there is any justice. I am done. I am resigning, and I am also applying to have my name removed from some of the Pangu technical reports. Having my name on them is a stain I can never erase. I did not expect them to be brazen enough to open-source the model, or to fool the world with such loud publicity. Perhaps I signed out of wishful thinking at the time. I believe many comrades who do solid work were likewise dragged aboard or kept in the dark. It cannot be undone now; I only hope to spend the rest of my life doing solid, genuinely meaningful work, to atone for my weakness and irresolution.

Writing this late at night, I am in tears. I remember asking some excellent departing colleagues, half joking, whether they would post the customary long farewell message to expose the situation. They said: no, it is a waste of time, and I am afraid exposing it would make things worse for those of you who stay. I was crushed; comrades who once fought for the same ideal had completely given up on Huawei. We used to joke that we fight with the Communist Party's "millet plus rifles" of those years, while the organization has the style of the Nationalist army of the same era.

There was a time I was proud that we beat foreign guns and cannons with millet and rifles.

Now I am tired. I want to surrender.

Even today I sincerely hope Huawei will truly learn its lesson, build Pangu well, make it world-class, and bring Ascend up to NVIDIA's level. The internal triumph of bad money over good has drained Noah's Ark, and Huawei, of a large number of outstanding LLM people in a short time. I believe they are now shining in teams like DeepSeek, fulfilling their ambitions and contributing to the fierce AI race between China and the US. I often sigh that Huawei does not lack talent; it simply does not know how to keep it. With the right environment, the right resources, fewer shackles, and less political infighting, why would Pangu not succeed?

Finally: I swear on my life, character, and honor that everything written above is true (at least within my limited knowledge). I do not have the skill or the opportunity to do a thorough, rigorous analysis, nor do I dare cite internal records directly for fear of being caught through information security. But I believe many of my former comrades will vouch for me. For those inside Huawei, including the product-line brothers we once served, the countless details in this letter should match your own recollections and confirm what I say. You may have been deceived too, but these cruel truths will not stay buried. The traces of our struggles should not be distorted and buried either.

Having written all this, certain people will surely want to find me and erase me. The company may also want to silence or pursue me. If that happens, my safety, and even my family's, may be threatened. To protect myself, I will check in with everyone every day for the time being.

If I disappear, take it that I sacrificed myself for truth and ideals, for Huawei and for China to build better compute and AI. I would be willing to be buried in the place where I once fought.

Goodbye, Noah's Ark.

Written in Shenzhen in the early hours of July 6, 2025

---

Hello everyone,

Thank you all for your concern and good wishes. I am safe for now, but the company appears to be conducting an investigation and compiling certain lists; what happens next is unknown.

Let me add a few details, so that certain people cannot keep inverting black and white.

Regarding 135B V2: after quickly finishing the shell-wrap and collecting all the rewards it brought (such as task-order commendations and spot bonuses), the Small Model Lab, unwilling to keep supporting downstream applications and model iteration, tossed this hot potato to the Fourth Column. Clever indeed; it dragged the Fourth Column brothers into the mess. Colleagues handed over an old model and got back what was then a modified, state-of-the-art Qwen. People who build large models know their own models like their own children; do not treat everyone else as fools. It is like your son going out for a stroll and someone else's child coming back.

The author lists on the Pangu reports do not meet academic norms. For example, on 135B V3 many people with real technical contributions were left off because of a cap on the number of authors, and their work never received its due credit; there was considerable resentment in the team. That model was the crystallization of everyone's wisdom and sweat, even the team's spiritual pillar at the time, keeping many brothers at Noah's Ark. The so-called author cap, alongside the inclusion of people with no technical contribution at all (such as some from the Small Model Lab), chilled everyone to the bone.

---

Safe for now. Also, my thanks to the comrades supporting me in speaking the truth: https://github.com/HW-whistleblower/True-Story-of-Pangu/issues/317
@@ -1,193 +0,0 @@
import os
import sys
import asyncio
import dotenv
import argparse
from pathlib import Path
from typing import List, Any

from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter

dotenv.load_dotenv()

# Auto-detect user's mail path
def get_mail_path():
    """Get the mail path for the current user"""
    home_dir = os.path.expanduser("~")
    return os.path.join(home_dir, "Library", "Mail")

def create_leann_index_from_multiple_sources(messages_dirs: List[Path], index_path: str = "mail_index.leann", max_count: int = -1, include_html: bool = False, embedding_model: str = "facebook/contriever"):
    """
    Create LEANN index from multiple mail data sources.

    Args:
        messages_dirs: List of Path objects pointing to Messages directories
        index_path: Path to save the LEANN index
        max_count: Maximum number of emails to process per directory
        include_html: Whether to include HTML content in email processing
    """
    print("Creating LEANN index from multiple mail data sources...")

    # Load documents using EmlxReader from local readers module
    from .readers import EmlxReader, find_all_messages_directories
    reader = EmlxReader(include_html=include_html)
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each Messages directory
        for i, messages_dir in enumerate(messages_dirs):
            print(f"\nProcessing Messages directory {i+1}/{len(messages_dirs)}: {messages_dir}")

            try:
                documents = reader.load_data(messages_dir)
                if documents:
                    print(f"Loaded {len(documents)} email documents from {messages_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    # Check if we've reached the max count
                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
                        break
                else:
                    print(f"No documents loaded from {messages_dir}")
            except Exception as e:
                print(f"Error processing {messages_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            return None

        print(f"\nTotal loaded {len(all_documents)} email documents from {len(messages_dirs)} directories")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in all_documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

        # Create LEANN index directory
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model=embedding_model,
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} email chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path

async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path,
                     llm_config={"type": "openai", "model": "gpt-4o"})

    print(f"You: {query}")
    import time
    start_time = time.time()
    chat_response = chat.ask(
        query,
        top_k=10,
        recompute_beighbor_embeddings=True,
        complexity=12,
        beam_width=1,
    )
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
    print(f"Leann: {chat_response}")

async def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='LEANN Mail Reader - Create and query email index')
    parser.add_argument('--index-dir', type=str, default="./mail_index_leann_raw_text_all_dicts",
                        help='Directory to store the LEANN index (default: ./mail_index_leann_raw_text_all_dicts)')
    parser.add_argument('--max-emails', type=int, default=1000,
                        help='Maximum number of emails to process (-1 means all)')
    parser.add_argument('--query', type=str, default="Give me some funny advertisement about apple or other companies",
                        help='Single query to run (default: runs example queries)')
    parser.add_argument('--include-html', action='store_true', default=False,
                        help='Include HTML content in email processing (default: False)')
    parser.add_argument('--embedding-model', type=str, default="facebook/contriever",
                        help='Embedding model to use (default: facebook/contriever)')

    args = parser.parse_args()

    print(f"args: {args}")

    # Automatically find all Messages directories under the current user's Mail directory
    from .readers import find_all_messages_directories
    mail_path = get_mail_path()
    print(f"Searching for email data in: {mail_path}")
    messages_dirs = find_all_messages_directories(mail_path)

    print('len(messages_dirs): ', len(messages_dirs))

    if not messages_dirs:
        print("No Messages directories found. Exiting.")
        return

    INDEX_DIR = Path(args.index_dir)
    INDEX_PATH = str(INDEX_DIR / "mail_documents.leann")
    print(f"Index directory: {INDEX_DIR}")
    print(f"Found {len(messages_dirs)} Messages directories.")

    # Create or load the LEANN index from all sources
    index_path = create_leann_index_from_multiple_sources(messages_dirs, INDEX_PATH, args.max_emails, args.include_html, args.embedding_model)

    if index_path:
        if args.query:
            # Run single query
            await query_leann_index(index_path, args.query)
        else:
            # Example queries
            queries = [
                "Hows Berkeley Graduate Student Instructor",
                "how's the icloud related advertisement saying",
                "Whats the number of class recommend to take per semester for incoming EECS students"
            ]
            for query in queries:
                print("\n" + "="*60)
                await query_leann_index(index_path, query)

if __name__ == "__main__":
    asyncio.run(main())
@@ -1,124 +0,0 @@
import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader

def find_all_messages_directories(root: str = None) -> List[Path]:
    """
    Recursively find all 'Messages' directories under the given root.
    Returns a list of Path objects.
    """
    if root is None:
        # Auto-detect user's mail path
        home_dir = os.path.expanduser("~")
        root = os.path.join(home_dir, "Library", "Mail")

    messages_dirs = []
    for dirpath, dirnames, filenames in os.walk(root):
        if os.path.basename(dirpath) == "Messages":
            messages_dirs.append(Path(dirpath))
    return messages_dirs

class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: List[Document] = []
        max_count = load_kwargs.get('max_count', 1000)
        count = 0

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split('\n', 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get('Subject', 'No Subject')
                                from_addr = msg.get('From', 'Unknown')
                                to_addr = msg.get('To', 'Unknown')
                                date = msg.get('Date', 'Unknown')

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
                                            if part.get_content_type() == "text/html" and not self.include_html:
                                                continue
                                            body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
                                            # break
                                else:
                                    body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')

                                # Create document content with metadata embedded in text
                                doc_content = f"""
[EMAIL METADATA]
File: {filename}
From: {from_addr}
To: {to_addr}
Subject: {subject}
Date: {date}
[END METADATA]

{body}
"""

                                # No separate metadata - everything is in the text
                                doc = Document(text=doc_content, metadata={})
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        print(f"Error reading file {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs
167  apps/email_data/LEANN_email_reader.py  Normal file
@@ -0,0 +1,167 @@
import email
import os
from pathlib import Path
from typing import Any

from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


def find_all_messages_directories(root: str | None = None) -> list[Path]:
    """
    Recursively find all 'Messages' directories under the given root.
    Returns a list of Path objects.
    """
    if root is None:
        # Auto-detect user's mail path
        home_dir = os.path.expanduser("~")
        root = os.path.join(home_dir, "Library", "Mail")

    messages_dirs = []
    for dirpath, _dirnames, _filenames in os.walk(root):
        if os.path.basename(dirpath) == "Messages":
            messages_dirs.append(Path(dirpath))
    return messages_dirs


class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        count = 0
        total_files = 0
        successful_files = 0
        failed_files = 0

        print(f"Starting to process directory: {input_dir}")

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                # Check if we've reached the max count (skip if max_count == -1)
                if max_count > 0 and count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    total_files += 1
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, encoding="utf-8", errors="ignore") as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split("\n", 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get("Subject", "No Subject")
                                from_addr = msg.get("From", "Unknown")
                                to_addr = msg.get("To", "Unknown")
                                date = msg.get("Date", "Unknown")

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if (
                                            part.get_content_type() == "text/plain"
                                            or part.get_content_type() == "text/html"
                                        ):
                                            if (
                                                part.get_content_type() == "text/html"
                                                and not self.include_html
                                            ):
                                                continue
                                            try:
                                                payload = part.get_payload(decode=True)
                                                if payload:
                                                    body += payload.decode("utf-8", errors="ignore")
                                            except Exception as e:
                                                print(f"Error decoding payload: {e}")
                                                continue
                                else:
                                    try:
                                        payload = msg.get_payload(decode=True)
                                        if payload:
                                            body = payload.decode("utf-8", errors="ignore")
                                    except Exception as e:
                                        print(f"Error decoding single part payload: {e}")
                                        body = ""

                                # Only create document if we have some content
                                if body.strip() or subject != "No Subject":
                                    # Create document content with metadata embedded in text
                                    doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""

                                    # No separate metadata - everything is in the text
                                    doc = Document(text=doc_content, metadata={})
                                    docs.append(doc)
                                    count += 1
                                    successful_files += 1

                                    # Print first few successful files for debugging
                                    if successful_files <= 3:
                                        print(
                                            f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
                                        )

                            except Exception as e:
                                failed_files += 1
                                if failed_files <= 5:  # Only print first few errors
                                    print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        failed_files += 1
                        if failed_files <= 5:  # Only print first few errors
                            print(f"Error reading file {filepath}: {e}")
                        continue

        print("Processing summary:")
        print(f"  Total .emlx files found: {total_files}")
        print(f"  Successfully loaded: {successful_files}")
        print(f"  Failed to load: {failed_files}")
        print(f"  Final documents: {len(docs)}")

        return docs
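As a quick sanity check, here is a minimal sketch of how the new reader might be driven on its own. Only `EmlxReader` and `find_all_messages_directories` come from the file above; the flat-import path and the sample limits are illustrative assumptions:

```python
# Hypothetical usage sketch for the reader above; the import path and limits are illustrative.
from LEANN_email_reader import EmlxReader, find_all_messages_directories

dirs = find_all_messages_directories()      # scans ~/Library/Mail for Messages directories
reader = EmlxReader(include_html=False)
docs = []
for d in dirs[:1]:                          # just the first mailbox, as a smoke test
    docs.extend(reader.load_data(str(d), max_count=10))
print(f"Loaded {len(docs)} sample documents")
```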
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Optional
-from fsspec import AbstractFileSystem
+from typing import Any
 
+from fsspec import AbstractFileSystem
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.schema import Document
 
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
     """
 
     DEFAULT_MESSAGE_FORMAT: str = (
-        "Date: {_date}\n"
-        "From: {_from}\n"
-        "To: {_to}\n"
-        "Subject: {_subject}\n"
-        "Content: {_content}"
+        "Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
     )
 
     def __init__(
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
         try:
             from bs4 import BeautifulSoup  # noqa
         except ImportError:
-            raise ImportError(
-                "`beautifulsoup4` package not found: `pip install beautifulsoup4`"
-            )
+            raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
 
         super().__init__(*args, **kwargs)
         self.max_count = max_count
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
     def load_data(
         self,
         file: Path,
-        extra_info: Optional[Dict] = None,
-        fs: Optional[AbstractFileSystem] = None,
-    ) -> List[Document]:
+        extra_info: dict | None = None,
+        fs: AbstractFileSystem | None = None,
+    ) -> list[Document]:
         """Parse file into string."""
         # Import required libraries
         import mailbox
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
         )
 
         i = 0
-        results: List[str] = []
+        results: list[str] = []
         # Load file using mailbox
         bytes_parser = BytesParser(policy=default).parse
         mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore
@@ -134,12 +128,12 @@ class EmlxMboxReader(MboxReader):
     def load_data(
         self,
        directory: Path,
-        extra_info: Optional[Dict] = None,
-        fs: Optional[AbstractFileSystem] = None,
-    ) -> List[Document]:
+        extra_info: dict | None = None,
+        fs: AbstractFileSystem | None = None,
+    ) -> list[Document]:
         """Parse .emlx files from directory into strings using MboxReader logic."""
-        import tempfile
         import os
+        import tempfile
 
         if fs:
             logger.warning(
@@ -156,18 +150,18 @@ class EmlxMboxReader(MboxReader):
             return []
 
         # Create a temporary mbox file
-        with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
             temp_mbox_path = temp_mbox.name
 
         # Convert .emlx files to mbox format
         for emlx_file in emlx_files:
             try:
                 # Read the .emlx file
-                with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
+                with open(emlx_file, encoding="utf-8", errors="ignore") as f:
                     content = f.read()
 
                 # .emlx format: first line is length, rest is email content
-                lines = content.split('\n', 1)
+                lines = content.split("\n", 1)
                 if len(lines) >= 2:
                     email_content = lines[1]  # Skip the length line
 
@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
         # Clean up temporary file
         try:
             os.unlink(temp_mbox_path)
-        except:
+        except OSError:
             pass
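Most of the hunks above migrate annotations from the `typing` module style to built-in generics and the union operator. A minimal sketch of the two equivalent forms, assuming Python 3.10 or newer (that version requirement is an inference, not something the diff states):

```python
# Old typing-module style versus the new built-in style used by the updated readers.
# The new style needs Python >= 3.10, or `from __future__ import annotations`.
from typing import Dict, List, Optional


def old_style(extra_info: Optional[Dict] = None) -> List[str]:
    return list((extra_info or {}).keys())


def new_style(extra_info: dict | None = None) -> list[str]:
    return list((extra_info or {}).keys())
```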
157  apps/email_rag.py  Normal file
@@ -0,0 +1,157 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks

from .email_data.LEANN_email_reader import EmlxReader


class EmailRAG(BaseRAGExample):
    """RAG example for Apple Mail processing."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.max_items_default = -1  # Process all emails by default
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="Email",
            description="Process and query Apple Mail emails with LEANN",
            default_index_name="mail_index",
        )

    def _add_specific_arguments(self, parser):
        """Add email-specific arguments."""
        email_group = parser.add_argument_group("Email Parameters")
        email_group.add_argument(
            "--mail-path",
            type=str,
            default=None,
            help="Path to Apple Mail directory (auto-detected if not specified)",
        )
        email_group.add_argument(
            "--include-html", action="store_true", help="Include HTML content in email processing"
        )
        email_group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        email_group.add_argument(
            "--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
        )

    def _find_mail_directories(self) -> list[Path]:
        """Auto-detect all Apple Mail directories."""
        mail_base = Path.home() / "Library" / "Mail"
        if not mail_base.exists():
            return []

        # Find all Messages directories
        messages_dirs = []
        for item in mail_base.rglob("Messages"):
            if item.is_dir():
                messages_dirs.append(item)

        return messages_dirs

    async def load_data(self, args) -> list[str]:
        """Load emails and convert to text chunks."""
        # Determine mail directories
        if args.mail_path:
            messages_dirs = [Path(args.mail_path)]
        else:
            print("Auto-detecting Apple Mail directories...")
            messages_dirs = self._find_mail_directories()

        if not messages_dirs:
            print("No Apple Mail directories found!")
            print("Please specify --mail-path manually")
            return []

        print(f"Found {len(messages_dirs)} mail directories")

        # Create reader
        reader = EmlxReader(include_html=args.include_html)

        # Process each directory
        all_documents = []
        total_processed = 0

        for i, messages_dir in enumerate(messages_dirs):
            print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")

            try:
                # Count emlx files
                emlx_files = list(messages_dir.glob("*.emlx"))
                print(f"Found {len(emlx_files)} email files")

                # Apply max_items limit per directory
                max_per_dir = -1  # Default to process all
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_dir = remaining
                # If args.max_items == -1, max_per_dir stays -1 (process all)

                # Load emails - fix the parameter passing
                documents = reader.load_data(
                    input_dir=str(messages_dir),
                    max_count=max_per_dir,
                )

                if documents:
                    all_documents.extend(documents)
                    total_processed += len(documents)
                    print(f"Processed {len(documents)} emails from this directory")

            except Exception as e:
                print(f"Error processing {messages_dir}: {e}")
                continue

        if not all_documents:
            print("No emails found to process!")
            return []

        print(f"\nTotal emails processed: {len(all_documents)}")
        print("now starting to split into text chunks ... take some time")

        # Convert to text chunks
        # Email reader uses chunk_overlap=25 as in original
        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        return all_texts


if __name__ == "__main__":
    import asyncio

    # Check platform
    if sys.platform != "darwin":
        print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
        print("   Windows/Linux support coming soon!\n")

    # Example queries for email RAG
    print("\n📧 Email RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What did my boss say about deadlines?'")
    print("- 'Find emails about travel expenses'")
    print("- 'Show me emails from last month about the project'")
    print("- 'What food did I order from DoorDash?'")
    print("\nNote: You may need to grant Full Disk Access to your terminal\n")

    rag = EmailRAG()
    asyncio.run(rag.run())
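A hedged sketch of driving `EmailRAG` programmatically rather than via the shell. The email-specific flags come from `_add_specific_arguments` above; the import path and any flags inherited from `BaseRAGExample` are assumptions, since that base class is not shown in this diff:

```python
# Illustrative only: run the EmailRAG example above with a fixed set of CLI flags.
import asyncio
import sys

sys.argv = [
    "email_rag.py",
    "--include-html",          # defined in _add_specific_arguments above
    "--chunk-size", "256",
    "--chunk-overlap", "25",
]

from apps.email_rag import EmailRAG  # import path is an assumption about the repo layout

asyncio.run(EmailRAG().run())
```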
@@ -1,382 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
This script runs a recall evaluation on a given LEANN index.
|
|
||||||
It correctly compares results by fetching the text content for both the new search
|
|
||||||
results and the golden standard results, making the comparison robust to ID changes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher, LeannBuilder
|
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
|
||||||
if not data_root.exists():
|
|
||||||
print(f"Data directory '{data_root}' not found.")
|
|
||||||
print(
|
|
||||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
if download_embeddings:
|
|
||||||
# Download everything including embeddings (large files)
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=data_root,
|
|
||||||
local_dir_use_symlinks=False,
|
|
||||||
)
|
|
||||||
print("Data download complete (including embeddings)!")
|
|
||||||
else:
|
|
||||||
# Download only specific folders, excluding embeddings
|
|
||||||
allow_patterns = [
|
|
||||||
"ground_truth/**",
|
|
||||||
"indices/**",
|
|
||||||
"queries/**",
|
|
||||||
"*.md",
|
|
||||||
"*.txt",
|
|
||||||
]
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=data_root,
|
|
||||||
local_dir_use_symlinks=False,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
)
|
|
||||||
print("Data download complete (excluding embeddings)!")
|
|
||||||
except ImportError:
|
|
||||||
print(
|
|
||||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
|
||||||
)
|
|
||||||
print("uv pip install -e '.[dev]'")
|
|
||||||
sys.exit(1)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred during data download: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
|
||||||
"""Download embeddings files specifically."""
|
|
||||||
embeddings_dir = data_root / "embeddings"
|
|
||||||
|
|
||||||
if dataset_type:
|
|
||||||
# Check if specific dataset embeddings exist
|
|
||||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
|
||||||
if target_file.exists():
|
|
||||||
print(f"Embeddings for {dataset_type} already exist")
|
|
||||||
return str(target_file)
|
|
||||||
|
|
||||||
print("Downloading embeddings from HuggingFace Hub...")
|
|
||||||
try:
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
# Download only embeddings folder
|
|
||||||
snapshot_download(
|
|
||||||
repo_id="LEANN-RAG/leann-rag-evaluation-data",
|
|
||||||
repo_type="dataset",
|
|
||||||
local_dir=data_root,
|
|
||||||
local_dir_use_symlinks=False,
|
|
||||||
allow_patterns=["embeddings/**/*.pkl"],
|
|
||||||
)
|
|
||||||
print("Embeddings download complete!")
|
|
||||||
|
|
||||||
if dataset_type:
|
|
||||||
target_file = embeddings_dir / dataset_type / "passages_00.pkl"
|
|
||||||
if target_file.exists():
|
|
||||||
return str(target_file)
|
|
||||||
|
|
||||||
return str(embeddings_dir)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error downloading embeddings: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
|
||||||
"""
|
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
|
||||||
passage manager.
|
|
||||||
"""
|
|
||||||
golden_texts = set()
|
|
||||||
for gid in golden_ids:
|
|
||||||
try:
|
|
||||||
# PassageManager uses string IDs
|
|
||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
|
||||||
golden_texts.add(passage_data["text"])
|
|
||||||
except KeyError:
|
|
||||||
print(
|
|
||||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
|
||||||
)
|
|
||||||
return golden_texts
|
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> List[str]:
|
|
||||||
queries = []
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
data = json.loads(line)
|
|
||||||
queries.append(data["query"])
|
|
||||||
return queries
|
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(
|
|
||||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Build a LEANN index from pre-computed embeddings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embeddings_file: Path to pickle file with (ids, embeddings) tuple
|
|
||||||
output_path: Path where to save the index
|
|
||||||
backend: Backend to use ("hnsw" or "diskann")
|
|
||||||
"""
|
|
||||||
print(f"Building {backend} index from embeddings: {embeddings_file}")
|
|
||||||
|
|
||||||
# Create builder with appropriate parameters
|
|
||||||
if backend == "hnsw":
|
|
||||||
builder_kwargs = {
|
|
||||||
"M": 32, # Graph degree
|
|
||||||
"efConstruction": 256, # Construction complexity
|
|
||||||
"is_compact": True, # Use compact storage
|
|
||||||
"is_recompute": True, # Enable pruning for better recall
|
|
||||||
}
|
|
||||||
elif backend == "diskann":
|
|
||||||
builder_kwargs = {
|
|
||||||
"complexity": 64,
|
|
||||||
"graph_degree": 32,
|
|
||||||
"search_memory_maximum": 8.0, # GB
|
|
||||||
"build_memory_maximum": 16.0, # GB
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
builder_kwargs = {}
|
|
||||||
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name=backend,
|
|
||||||
embedding_model="facebook/contriever-msmarco", # Model used to create embeddings
|
|
||||||
dimensions=768, # Will be auto-detected from embeddings
|
|
||||||
**builder_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build index from precomputed embeddings
|
|
||||||
builder.build_index_from_embeddings(output_path, embeddings_file)
|
|
||||||
print(f"Index saved to: {output_path}")
|
|
||||||
return output_path
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Run recall evaluation on a LEANN index."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"index_path",
|
|
||||||
type=str,
|
|
||||||
nargs="?",
|
|
||||||
help="Path to the LEANN index to evaluate or build (optional).",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--mode",
|
|
||||||
choices=["evaluate", "build"],
|
|
||||||
default="evaluate",
|
|
||||||
help="Mode: 'evaluate' existing index or 'build' from embeddings",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--embeddings-file",
|
|
||||||
type=str,
|
|
||||||
help="Path to embeddings pickle file (optional for build mode)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
choices=["hnsw", "diskann"],
|
|
||||||
default="hnsw",
|
|
||||||
help="Backend to use for building index (default: hnsw)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# --- Path Configuration ---
|
|
||||||
# Assumes a project structure where the script is in 'examples/'
|
|
||||||
# and data is in 'data/' at the project root.
|
|
||||||
project_root = Path(__file__).resolve().parent.parent
|
|
||||||
data_root = project_root / "data"
|
|
||||||
|
|
||||||
# Download data based on mode
|
|
||||||
if args.mode == "build":
|
|
||||||
# For building mode, we need embeddings
|
|
||||||
download_data_if_needed(
|
|
||||||
data_root, download_embeddings=False
|
|
||||||
) # Basic data first
|
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
|
||||||
if args.embeddings_file:
|
|
||||||
embeddings_file = args.embeddings_file
|
|
||||||
# Try to detect dataset type from embeddings file path
|
|
||||||
if "rpj_wiki" in str(embeddings_file):
|
|
||||||
dataset_type = "rpj_wiki"
|
|
||||||
elif "dpr" in str(embeddings_file):
|
|
||||||
dataset_type = "dpr"
|
|
||||||
else:
|
|
||||||
dataset_type = "dpr" # Default
|
|
||||||
else:
|
|
||||||
# Auto-detect from index path if provided, otherwise default to DPR
|
|
||||||
if args.index_path:
|
|
||||||
index_path_str = str(args.index_path)
|
|
||||||
if "rpj_wiki" in index_path_str:
|
|
||||||
dataset_type = "rpj_wiki"
|
|
||||||
elif "dpr" in index_path_str:
|
|
||||||
dataset_type = "dpr"
|
|
||||||
else:
|
|
||||||
dataset_type = "dpr" # Default to DPR
|
|
||||||
else:
|
|
||||||
dataset_type = "dpr" # Default to DPR
|
|
||||||
|
|
||||||
embeddings_file = download_embeddings_if_needed(data_root, dataset_type)
|
|
||||||
|
|
||||||
# Auto-generate index path if not provided
|
|
||||||
if not args.index_path:
|
|
||||||
indices_dir = data_root / "indices" / dataset_type
|
|
||||||
indices_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
args.index_path = str(indices_dir / f"{dataset_type}_from_embeddings")
|
|
||||||
print(f"Auto-generated index path: {args.index_path}")
|
|
||||||
|
|
||||||
print(f"Building index from embeddings: {embeddings_file}")
|
|
||||||
built_index_path = build_index_from_embeddings(
|
|
||||||
embeddings_file, args.index_path, args.backend
|
|
||||||
)
|
|
||||||
print(f"Index built successfully: {built_index_path}")
|
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
|
||||||
eval_response = (
|
|
||||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
|
||||||
)
|
|
||||||
if eval_response != "y":
|
|
||||||
print("Index building complete. Exiting.")
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# For evaluation mode, don't need embeddings
|
|
||||||
download_data_if_needed(data_root, download_embeddings=False)
|
|
||||||
|
|
||||||
# Auto-detect index path if not provided
|
|
||||||
if not args.index_path:
|
|
||||||
# Default to using downloaded indices
|
|
||||||
indices_dir = data_root / "indices"
|
|
||||||
|
|
||||||
# Try common datasets in order of preference
|
|
||||||
for dataset in ["dpr", "rpj_wiki"]:
|
|
||||||
dataset_dir = indices_dir / dataset
|
|
||||||
if dataset_dir.exists():
|
|
||||||
# Look for index files
|
|
||||||
index_files = list(dataset_dir.glob("*.index")) + list(
|
|
||||||
dataset_dir.glob("*_disk.index")
|
|
||||||
)
|
|
||||||
if index_files:
|
|
||||||
args.index_path = str(
|
|
||||||
index_files[0].with_suffix("")
|
|
||||||
) # Remove .index extension
|
|
||||||
print(f"Using index: {args.index_path}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if not args.index_path:
|
|
||||||
print(
|
|
||||||
"No indices found. The data download should have included pre-built indices."
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
"Please check the data/indices/ directory or provide --index-path manually."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Detect dataset type from index path to select the correct ground truth
|
|
||||||
index_path_str = str(args.index_path)
|
|
||||||
if "rpj_wiki" in index_path_str:
|
|
||||||
dataset_type = "rpj_wiki"
|
|
||||||
elif "dpr" in index_path_str:
|
|
||||||
dataset_type = "dpr"
|
|
||||||
else:
|
|
||||||
# Fallback: try to infer from the index directory name
|
|
||||||
dataset_type = Path(args.index_path).name
|
|
||||||
print(
|
|
||||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
|
||||||
golden_results_file = (
|
|
||||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
|
||||||
print(f"INFO: Using ground truth file: {golden_results_file}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
searcher = LeannSearcher(args.index_path)
|
|
||||||
queries = load_queries(queries_file)
|
|
||||||
|
|
||||||
with open(golden_results_file, "r") as f:
|
|
||||||
golden_results_data = json.load(f)
|
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
|
||||||
queries = queries[:num_eval_queries]
|
|
||||||
|
|
||||||
print(f"\nRunning evaluation on {num_eval_queries} queries...")
|
|
||||||
recall_scores = []
|
|
||||||
search_times = []
|
|
||||||
|
|
||||||
for i in range(num_eval_queries):
|
|
||||||
start_time = time.time()
|
|
||||||
new_results = searcher.search(
|
|
||||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
|
||||||
)
|
|
||||||
search_times.append(time.time() - start_time)
|
|
||||||
|
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
|
||||||
new_texts = {result.text for result in new_results}
|
|
||||||
|
|
||||||
# Get golden texts directly from the searcher's passage manager
|
|
||||||
golden_ids = golden_results_data["indices"][i][: args.top_k]
|
|
||||||
golden_texts = get_golden_texts(searcher, golden_ids)
|
|
||||||
|
|
||||||
overlap = len(new_texts & golden_texts)
|
|
||||||
recall = overlap / len(golden_texts) if golden_texts else 0
|
|
||||||
recall_scores.append(recall)
|
|
||||||
|
|
||||||
print("\n--- EVALUATION RESULTS ---")
|
|
||||||
print(f"Query: {queries[i]}")
|
|
||||||
print(f"New Results: {new_texts}")
|
|
||||||
print(f"Golden Results: {golden_texts}")
|
|
||||||
print(f"Overlap: {overlap}")
|
|
||||||
print(f"Recall: {recall}")
|
|
||||||
print(f"Search Time: {search_times[-1]:.4f}s")
|
|
||||||
print("--------------------------------")
|
|
||||||
|
|
||||||
avg_recall = np.mean(recall_scores) if recall_scores else 0
|
|
||||||
avg_time = np.mean(search_times) if search_times else 0
|
|
||||||
|
|
||||||
print("\n🎉 --- Evaluation Complete ---")
|
|
||||||
print(f"Avg. Recall@{args.top_k} (efSearch={args.ef_search}): {avg_recall:.4f}")
|
|
||||||
print(f"Avg. Search Time: {avg_time:.4f}s")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ An error occurred during evaluation: {e}")
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,3 +1,3 @@
 from .history import ChromeHistoryReader
 
-__all__ = ['ChromeHistoryReader']
+__all__ = ["ChromeHistoryReader"]
@@ -1,10 +1,12 @@
-import sqlite3
 import os
+import sqlite3
 from pathlib import Path
-from typing import List, Any
+from typing import Any
 
 from llama_index.core import Document
 from llama_index.core.readers.base import BaseReader
 
 
 class ChromeHistoryReader(BaseReader):
     """
     Chrome browser history reader that extracts browsing data from SQLite database.
@@ -17,7 +19,7 @@ class ChromeHistoryReader(BaseReader):
         """Initialize."""
         pass
 
-    def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
+    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
         """
         Load Chrome history data from the default Chrome profile location.
 
@@ -27,13 +29,15 @@ class ChromeHistoryReader(BaseReader):
                 max_count (int): Maximum amount of history entries to read.
                 chrome_profile_path (str): Custom path to Chrome profile directory.
         """
-        docs: List[Document] = []
-        max_count = load_kwargs.get('max_count', 1000)
-        chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
+        docs: list[Document] = []
+        max_count = load_kwargs.get("max_count", 1000)
+        chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
 
         # Default Chrome profile path on macOS
         if chrome_profile_path is None:
-            chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
+            chrome_profile_path = os.path.expanduser(
+                "~/Library/Application Support/Google/Chrome/Default"
+            )
 
         history_db_path = os.path.join(chrome_profile_path, "History")
 
@@ -70,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
                 if count >= max_count and max_count > 0:
                     break
 
-                last_visit, url, title, visit_count, typed_count, hidden = row
+                last_visit, url, title, visit_count, typed_count, _hidden = row
 
                 # Create document content with metadata embedded in text
                 doc_content = f"""
@@ -82,7 +86,7 @@ class ChromeHistoryReader(BaseReader):
 """
 
                 # Create document with embedded metadata
-                doc = Document(text=doc_content, metadata={ "title": title[0:150]})
+                doc = Document(text=doc_content, metadata={"title": title[0:150]})
                 # if len(title) > 150:
                 #     print(f"Title is too long: {title}")
                 docs.append(doc)
@@ -93,12 +97,17 @@ class ChromeHistoryReader(BaseReader):
 
         except Exception as e:
             print(f"Error reading Chrome history: {e}")
+            # add you may need to close your browser to make the database file available
+            # also highlight in red
+            print(
+                "\033[91mYou may need to close your browser to make the database file available\033[0m"
+            )
             return docs
 
         return docs
 
     @staticmethod
-    def find_chrome_profiles() -> List[Path]:
+    def find_chrome_profiles() -> list[Path]:
         """
         Find all Chrome profile directories.
 
@@ -124,7 +133,9 @@ class ChromeHistoryReader(BaseReader):
         return profile_dirs
 
     @staticmethod
-    def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
+    def export_history_to_file(
+        output_file: str = "chrome_history_export.txt", max_count: int = 1000
+    ):
         """
         Export Chrome history to a text file using the same SQL query format.
 
@@ -132,7 +143,9 @@ class ChromeHistoryReader(BaseReader):
             output_file: Path to the output file
             max_count: Maximum number of entries to export
         """
-        chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
+        chrome_profile_path = os.path.expanduser(
+            "~/Library/Application Support/Google/Chrome/Default"
+        )
         history_db_path = os.path.join(chrome_profile_path, "History")
 
         if not os.path.exists(history_db_path):
@@ -159,10 +172,12 @@ class ChromeHistoryReader(BaseReader):
         cursor.execute(query, (max_count,))
         rows = cursor.fetchall()
 
-        with open(output_file, 'w', encoding='utf-8') as f:
+        with open(output_file, "w", encoding="utf-8") as f:
             for row in rows:
                 last_visit, url, title, visit_count, typed_count, hidden = row
-                f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
+                f.write(
+                    f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
+                )
 
         conn.close()
         print(f"Exported {len(rows)} history entries to {output_file}")
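A brief sketch of exercising the reader after these changes. The `max_count` keyword and the `export_history_to_file` helper come from the code above; the import path is an assumption about the package layout, and Chrome usually has to be closed so the History SQLite file is not locked:

```python
# Illustrative only: exercise ChromeHistoryReader as changed above.
from apps.browser_data.history import ChromeHistoryReader  # import path is an assumption

reader = ChromeHistoryReader()
docs = reader.load_data(max_count=50)   # defaults to the macOS "Default" profile
print(f"Loaded {len(docs)} history documents")

# Or dump raw rows to a tab-separated text file via the static helper:
ChromeHistoryReader.export_history_to_file("chrome_history_export.txt", max_count=50)
```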
@@ -2,13 +2,14 @@ import json
 import os
 import re
 import subprocess
-import sys
 import time
+from datetime import datetime
 from pathlib import Path
-from typing import List, Any, Dict, Optional
+from typing import Any
 
 from llama_index.core import Document
 from llama_index.core.readers.base import BaseReader
-from datetime import datetime
 
 class WeChatHistoryReader(BaseReader):
     """
@@ -43,10 +44,16 @@ class WeChatHistoryReader(BaseReader):
         wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
         if not wechattweak_path.exists():
             print("Downloading WeChatTweak CLI...")
-            subprocess.run([
-                "curl", "-L", "-o", str(wechattweak_path),
-                "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
-            ], check=True)
+            subprocess.run(
+                [
+                    "curl",
+                    "-L",
+                    "-o",
+                    str(wechattweak_path),
+                    "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
+                ],
+                check=True,
+            )
 
             # Make executable
             wechattweak_path.chmod(0o755)
@@ -73,16 +80,16 @@ class WeChatHistoryReader(BaseReader):
     def check_api_available(self) -> bool:
         """Check if WeChatTweak API is available."""
         try:
-            result = subprocess.run([
-                "curl", "-s", "http://localhost:48065/wechat/allcontacts"
-            ], capture_output=True, text=True, timeout=5)
+            result = subprocess.run(
+                ["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
             return result.returncode == 0 and result.stdout.strip()
         except Exception:
             return False
 
     def _extract_readable_text(self, content: str) -> str:
         """
         Extract readable text from message content, removing XML and system messages.
@@ -100,14 +107,14 @@ class WeChatHistoryReader(BaseReader):
         if isinstance(content, dict):
             # Extract text from dictionary structure
             text_parts = []
-            if 'title' in content:
-                text_parts.append(str(content['title']))
-            if 'quoted' in content:
-                text_parts.append(str(content['quoted']))
-            if 'content' in content:
-                text_parts.append(str(content['content']))
-            if 'text' in content:
-                text_parts.append(str(content['text']))
+            if "title" in content:
+                text_parts.append(str(content["title"]))
+            if "quoted" in content:
+                text_parts.append(str(content["quoted"]))
+            if "content" in content:
+                text_parts.append(str(content["content"]))
+            if "text" in content:
+                text_parts.append(str(content["text"]))
|
text_parts.append(str(content["text"]))
|
||||||
|
|
||||||
if text_parts:
|
if text_parts:
|
||||||
return " | ".join(text_parts)
|
return " | ".join(text_parts)
|
||||||
@@ -120,11 +127,11 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n"
|
# Remove common prefixes like "wxid_xxx:\n"
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If it's just XML or system message, return empty
|
# If it's just XML or system message, return empty
|
||||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return clean_content.strip()
|
return clean_content.strip()
|
||||||
@@ -145,9 +152,9 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Handle dictionary content
|
# Handle dictionary content
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
# Check if dict has any readable text fields
|
# Check if dict has any readable text fields
|
||||||
text_fields = ['title', 'quoted', 'content', 'text']
|
text_fields = ["title", "quoted", "content", "text"]
|
||||||
for field in text_fields:
|
for field in text_fields:
|
||||||
if field in content and content[field]:
|
if content.get(field):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -156,42 +163,47 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip image messages (contain XML with img tags)
|
# Skip image messages (contain XML with img tags)
|
||||||
if '<img' in content and 'cdnurl' in content:
|
if "<img" in content and "cdnurl" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip emoji messages (contain emoji XML tags)
|
# Skip emoji messages (contain emoji XML tags)
|
||||||
if '<emoji' in content and 'productid' in content:
|
if "<emoji" in content and "productid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip voice messages
|
# Skip voice messages
|
||||||
if '<voice' in content:
|
if "<voice" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip video messages
|
# Skip video messages
|
||||||
if '<video' in content:
|
if "<video" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip file messages
|
# Skip file messages
|
||||||
if '<appmsg' in content and 'appid' in content:
|
if "<appmsg" in content and "appid" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip system messages (like "recalled a message")
|
# Skip system messages (like "recalled a message")
|
||||||
if 'recalled a message' in content:
|
if "recalled a message" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if there's actual readable text (not just XML or system messages)
|
# Check if there's actual readable text (not just XML or system messages)
|
||||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||||
|
|
||||||
# If after cleaning we have meaningful text, consider it readable
|
# If after cleaning we have meaningful text, consider it readable
|
||||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
def _concatenate_messages(
|
||||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
self,
|
||||||
|
messages: list[dict],
|
||||||
|
max_length: int = 128,
|
||||||
|
time_window_minutes: int = 30,
|
||||||
|
overlap_messages: int = 0,
|
||||||
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Concatenate messages based on length and time rules.
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
@@ -214,12 +226,12 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Extract message info
|
# Extract message info
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -236,16 +248,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if time_diff_minutes > time_window_minutes:
|
if time_diff_minutes > time_window_minutes:
|
||||||
# Time gap too large, start new group
|
# Time gap too large, start new group
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -254,16 +274,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
message_length = len(readable_text)
|
message_length = len(readable_text)
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
# Current group would exceed max length, save it and start new
|
# Current group would exceed max length, save it and start new
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -275,16 +303,18 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
# Add the last group if it exists
|
# Add the last group if it exists
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return concatenated_groups
|
return concatenated_groups
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create concatenated content from a group of messages.
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
@@ -295,16 +325,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted concatenated content
|
Formatted concatenated content
|
||||||
"""
|
"""
|
||||||
messages = message_group['messages']
|
messages = message_group["messages"]
|
||||||
start_time = message_group['start_time']
|
start_time = message_group["start_time"]
|
||||||
end_time = message_group['end_time']
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
# Format timestamps
|
# Format timestamps
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
start_time_str = str(start_time)
|
start_time_str = str(start_time)
|
||||||
else:
|
else:
|
||||||
start_time_str = "Unknown"
|
start_time_str = "Unknown"
|
||||||
@@ -312,8 +342,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
end_time_str = str(end_time)
|
end_time_str = str(end_time)
|
||||||
else:
|
else:
|
||||||
end_time_str = "Unknown"
|
end_time_str = "Unknown"
|
||||||
@@ -321,10 +351,10 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Build concatenated message content
|
# Build concatenated message content
|
||||||
message_parts = []
|
message_parts = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -336,8 +366,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -351,7 +381,7 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
@@ -361,7 +391,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
"""
|
"""
|
||||||
return doc_content, contact_name
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load WeChat chat history data from exported JSON files.
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
@@ -376,13 +406,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||||
include_non_text = load_kwargs.get('include_non_text', False)
|
include_non_text = load_kwargs.get("include_non_text", False)
|
||||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||||
max_length = load_kwargs.get('max_length', 1000)
|
max_length = load_kwargs.get("max_length", 1000)
|
||||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||||
|
|
||||||
# Default WeChat export path
|
# Default WeChat export path
|
||||||
if wechat_export_dir is None:
|
if wechat_export_dir is None:
|
||||||
@@ -403,7 +433,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
with open(json_file, encoding="utf-8") as f:
|
||||||
chat_data = json.load(f)
|
chat_data = json.load(f)
|
||||||
|
|
||||||
# Extract contact name from filename
|
# Extract contact name from filename
|
||||||
@@ -414,7 +444,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
readable_messages = []
|
readable_messages = []
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
try:
|
try:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -430,9 +460,9 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
# Concatenate messages based on rules
|
# Concatenate messages based on rules
|
||||||
message_groups = self._concatenate_messages(
|
message_groups = self._concatenate_messages(
|
||||||
readable_messages,
|
readable_messages,
|
||||||
max_length=-1,
|
max_length=max_length,
|
||||||
time_window_minutes=-1,
|
time_window_minutes=time_window_minutes,
|
||||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
overlap_messages=0, # No overlap between groups
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
# Create documents from concatenated groups
|
||||||
@@ -440,12 +470,19 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
doc_content, contact_name = self._create_concatenated_content(
|
||||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
message_group, contact_name
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content,
|
||||||
|
metadata={"contact_name": contact_name},
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
print(
|
||||||
|
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Original single-message processing
|
# Original single-message processing
|
||||||
@@ -454,12 +491,12 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Extract message information
|
# Extract message information
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
# Handle content that might be dict or string
|
||||||
try:
|
try:
|
||||||
@@ -480,8 +517,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -495,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(
|
||||||
|
text=doc_content, metadata={"contact_name": contact_name}
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -512,7 +551,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_wechat_export_dirs() -> List[Path]:
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all WeChat export directories.
|
Find all WeChat export directories.
|
||||||
|
|
||||||
@@ -523,10 +562,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
|
|
||||||
# Look for common export directory names
|
# Look for common export directory names
|
||||||
possible_dirs = [
|
possible_dirs = [
|
||||||
Path("./wechat_export_test"),
|
|
||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
for export_dir in possible_dirs:
|
||||||
@@ -534,13 +573,20 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
json_files = list(export_dir.glob("*.json"))
|
json_files = list(export_dir.glob("*.json"))
|
||||||
if json_files:
|
if json_files:
|
||||||
export_dirs.append(export_dir)
|
export_dirs.append(export_dir)
|
||||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
return export_dirs
|
return export_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history to a text file.
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
@@ -560,14 +606,14 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
try:
|
try:
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
chat_data = json.load(json_f)
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
@@ -577,10 +623,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
from_user = message.get('fromUser', '')
|
from_user = message.get("fromUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
# Skip non-text messages unless requested
|
||||||
if not include_non_text:
|
if not include_non_text:
|
||||||
@@ -595,8 +641,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -613,7 +659,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
@@ -642,16 +688,21 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
if requirements_file.exists():
|
if requirements_file.exists():
|
||||||
print("Installing wechat-exporter requirements...")
|
print("Installing wechat-exporter requirements...")
|
||||||
subprocess.run([
|
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||||
"uv", "pip", "install", "-r", str(requirements_file)
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Run the export command
|
# Run the export command
|
||||||
print("Running wechat-exporter...")
|
print("Running wechat-exporter...")
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
[
|
||||||
"export-all", str(export_path)
|
sys.executable,
|
||||||
], capture_output=True, text=True, check=True)
|
str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all",
|
||||||
|
str(export_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
print("Export command output:")
|
print("Export command output:")
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
@@ -662,7 +713,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
# Check if export was successful
|
# Check if export was successful
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
json_files = list(export_path.glob("*.json"))
|
json_files = list(export_path.glob("*.json"))
|
||||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
print(
|
||||||
|
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||||
|
)
|
||||||
return export_path
|
return export_path
|
||||||
else:
|
else:
|
||||||
print("Export completed but no JSON files found")
|
print("Export completed but no JSON files found")
|
||||||
@@ -678,7 +731,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find existing WeChat exports or create new ones.
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
@@ -697,7 +750,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
for export_dir_path in possible_export_dirs:
|
||||||
@@ -714,6 +767,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if exported_path:
|
if exported_path:
|
||||||
export_dirs = [exported_path]
|
export_dirs = [exported_path]
|
||||||
else:
|
else:
|
||||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
|
)
|
||||||
|
|
||||||
return export_dirs
|
return export_dirs
|
||||||
@@ -0,0 +1,182 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_backend_hnsw.hnsw_backend import HNSWBuilder, HNSWSearcher # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class LeannMultiVector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
index_path: str,
|
||||||
|
dim: int = 128,
|
||||||
|
distance_metric: str = "mips",
|
||||||
|
m: int = 16,
|
||||||
|
ef_construction: int = 500,
|
||||||
|
is_compact: bool = False,
|
||||||
|
is_recompute: bool = False,
|
||||||
|
embedding_model_name: str = "colvision",
|
||||||
|
) -> None:
|
||||||
|
self.index_path = index_path
|
||||||
|
self.dim = dim
|
||||||
|
self.embedding_model_name = embedding_model_name
|
||||||
|
self._pending_items: list[dict] = []
|
||||||
|
self._backend_kwargs = {
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"M": m,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
"is_compact": is_compact,
|
||||||
|
"is_recompute": is_recompute,
|
||||||
|
}
|
||||||
|
self._labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
def _meta_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"version": "1.0",
|
||||||
|
"backend_name": "hnsw",
|
||||||
|
"embedding_model": self.embedding_model_name,
|
||||||
|
"embedding_mode": "custom",
|
||||||
|
"dimensions": self.dim,
|
||||||
|
"backend_kwargs": self._backend_kwargs,
|
||||||
|
"is_compact": self._backend_kwargs.get("is_compact", True),
|
||||||
|
"is_pruned": self._backend_kwargs.get("is_compact", True)
|
||||||
|
and self._backend_kwargs.get("is_recompute", True),
|
||||||
|
}
|
||||||
|
|
||||||
|
def create_collection(self) -> None:
|
||||||
|
path = Path(self.index_path)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def insert(self, data: dict) -> None:
|
||||||
|
self._pending_items.append(
|
||||||
|
{
|
||||||
|
"doc_id": int(data["doc_id"]),
|
||||||
|
"filepath": data.get("filepath", ""),
|
||||||
|
"colbert_vecs": [np.asarray(v, dtype=np.float32) for v in data["colbert_vecs"]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _labels_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.labels.json"
|
||||||
|
|
||||||
|
def _meta_path(self) -> Path:
|
||||||
|
index_path_obj = Path(self.index_path)
|
||||||
|
return index_path_obj.parent / f"{index_path_obj.name}.meta.json"
|
||||||
|
|
||||||
|
def create_index(self) -> None:
|
||||||
|
if not self._pending_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings: list[np.ndarray] = []
|
||||||
|
labels_meta: list[dict] = []
|
||||||
|
|
||||||
|
for item in self._pending_items:
|
||||||
|
doc_id = int(item["doc_id"])
|
||||||
|
filepath = item.get("filepath", "")
|
||||||
|
colbert_vecs = item["colbert_vecs"]
|
||||||
|
for seq_id, vec in enumerate(colbert_vecs):
|
||||||
|
vec_np = np.asarray(vec, dtype=np.float32)
|
||||||
|
embeddings.append(vec_np)
|
||||||
|
labels_meta.append(
|
||||||
|
{
|
||||||
|
"id": f"{doc_id}:{seq_id}",
|
||||||
|
"doc_id": doc_id,
|
||||||
|
"seq_id": int(seq_id),
|
||||||
|
"filepath": filepath,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
embeddings_np = np.vstack(embeddings).astype(np.float32)
|
||||||
|
# print shape of embeddings_np
|
||||||
|
print(embeddings_np.shape)
|
||||||
|
|
||||||
|
builder = HNSWBuilder(**{**self._backend_kwargs, "dimensions": self.dim})
|
||||||
|
ids = [str(i) for i in range(embeddings_np.shape[0])]
|
||||||
|
builder.build(embeddings_np, ids, self.index_path)
|
||||||
|
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(self._meta_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(self._meta_dict(), f, indent=2)
|
||||||
|
with open(self._labels_path(), "w", encoding="utf-8") as f:
|
||||||
|
_json.dump(labels_meta, f)
|
||||||
|
|
||||||
|
self._labels_meta = labels_meta
|
||||||
|
|
||||||
|
def _load_labels_meta_if_needed(self) -> None:
|
||||||
|
if self._labels_meta:
|
||||||
|
return
|
||||||
|
labels_path = self._labels_path()
|
||||||
|
if labels_path.exists():
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
with open(labels_path, encoding="utf-8") as f:
|
||||||
|
self._labels_meta = _json.load(f)
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self, data: np.ndarray, topk: int, first_stage_k: int = 50
|
||||||
|
) -> list[tuple[float, int]]:
|
||||||
|
if data.ndim == 1:
|
||||||
|
data = data.reshape(1, -1)
|
||||||
|
if data.dtype != np.float32:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
self._load_labels_meta_if_needed()
|
||||||
|
|
||||||
|
searcher = HNSWSearcher(self.index_path, meta=self._meta_dict())
|
||||||
|
raw = searcher.search(
|
||||||
|
data,
|
||||||
|
first_stage_k,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
complexity=128,
|
||||||
|
beam_width=1,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
batch_size=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
labels = raw.get("labels")
|
||||||
|
distances = raw.get("distances")
|
||||||
|
if labels is None or distances is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
doc_scores: dict[int, float] = {}
|
||||||
|
B = len(labels)
|
||||||
|
for b in range(B):
|
||||||
|
per_doc_best: dict[int, float] = {}
|
||||||
|
for k, sid in enumerate(labels[b]):
|
||||||
|
try:
|
||||||
|
idx = int(sid)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if 0 <= idx < len(self._labels_meta):
|
||||||
|
doc_id = int(self._labels_meta[idx]["doc_id"]) # type: ignore[index]
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
score = float(distances[b][k])
|
||||||
|
if (doc_id not in per_doc_best) or (score > per_doc_best[doc_id]):
|
||||||
|
per_doc_best[doc_id] = score
|
||||||
|
for doc_id, best_score in per_doc_best.items():
|
||||||
|
doc_scores[doc_id] = doc_scores.get(doc_id, 0.0) + best_score
|
||||||
|
|
||||||
|
scores = sorted(((v, k) for k, v in doc_scores.items()), key=lambda x: x[0], reverse=True)
|
||||||
|
return scores[:topk] if len(scores) >= topk else scores
|
||||||
@@ -0,0 +1,477 @@
|
|||||||
|
## Jupyter-style notebook script
|
||||||
|
# %%
|
||||||
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_repo_paths_importable(current_file: str) -> None:
|
||||||
|
"""Make local leann packages importable without installing (mirrors multi-vector-leann.py)."""
|
||||||
|
_repo_root = Path(current_file).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Config
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
|
||||||
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
|
USE_HF_DATASET: bool = True
|
||||||
|
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||||
|
DATASET_SPLIT: str = "train"
|
||||||
|
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||||
|
|
||||||
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
|
PAGES_DIR: str = "./pages"
|
||||||
|
|
||||||
|
# Index + retrieval settings
|
||||||
|
INDEX_PATH: str = "./indexes/colvision.leann"
|
||||||
|
TOPK: int = 1
|
||||||
|
FIRST_STAGE_K: int = 500
|
||||||
|
REBUILD_INDEX: bool = False
|
||||||
|
|
||||||
|
# Artifacts
|
||||||
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
|
SIMILARITY_MAP: bool = True
|
||||||
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
|
ANSWER: bool = True
|
||||||
|
MAX_NEW_TOKENS: int = 128
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Helpers
|
||||||
|
def _natural_sort_key(name: str) -> int:
|
||||||
|
m = re.search(r"\d+", name)
|
||||||
|
return int(m.group()) if m else 0
|
||||||
|
|
||||||
|
|
||||||
|
def _load_images_from_dir(pages_dir: str) -> tuple[list[str], list[Image.Image]]:
|
||||||
|
filenames = [n for n in os.listdir(pages_dir) if n.lower().endswith((".png", ".jpg", ".jpeg"))]
|
||||||
|
filenames = sorted(filenames, key=_natural_sort_key)
|
||||||
|
filepaths = [os.path.join(pages_dir, n) for n in filenames]
|
||||||
|
images = [Image.open(p) for p in filepaths]
|
||||||
|
return filepaths, images
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_convert_pdf_to_images(pdf_path: Optional[str], pages_dir: str, dpi: int = 200) -> None:
|
||||||
|
if not pdf_path:
|
||||||
|
return
|
||||||
|
os.makedirs(pages_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
"pdf2image is required to convert PDF to images. Install via pip install pdf2image"
|
||||||
|
) from e
|
||||||
|
images = convert_from_path(pdf_path, dpi=dpi)
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(os.path.join(pages_dir, f"page_{i + 1}.png"), "PNG")
|
||||||
|
|
||||||
|
|
||||||
|
def _select_device_and_dtype():
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import get_torch_device
|
||||||
|
|
||||||
|
device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(device_str)
|
||||||
|
# Stable dtype selection to avoid NaNs:
|
||||||
|
# - CUDA: prefer bfloat16 if supported, else float16
|
||||||
|
# - MPS: use float32 (fp16 on MPS can produce NaNs in some ops)
|
||||||
|
# - CPU: float32
|
||||||
|
if device_str == "cuda":
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
||||||
|
try:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # Better stability/perf on Ampere+
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
elif device_str == "mps":
|
||||||
|
dtype = torch.float32
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
return device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _load_colvision(model_choice: str):
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali, ColQwen2, ColQwen2Processor
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
device_str, device, dtype = _select_device_and_dtype()
|
||||||
|
|
||||||
|
if model_choice == "colqwen2":
|
||||||
|
model_name = "vidore/colqwen2-v1.0"
|
||||||
|
# On CPU/MPS we must avoid flash-attn and stay eager; on CUDA prefer flash-attn if available
|
||||||
|
attn_implementation = (
|
||||||
|
"flash_attention_2"
|
||||||
|
if (device_str == "cuda" and is_flash_attn_2_available())
|
||||||
|
else "eager"
|
||||||
|
)
|
||||||
|
model = ColQwen2.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
).eval()
|
||||||
|
processor = ColQwen2Processor.from_pretrained(model_name)
|
||||||
|
else:
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
return model_name, model, processor, device_str, device, dtype
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_images(model, processor, images: list[Image.Image]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Ensure deterministic eval and autocast for stability
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
doc_vecs: list[Any] = []
|
||||||
|
for batch_doc in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
# autocast on CUDA for bf16/fp16; on CPU/MPS stay in fp32
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
else:
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
doc_vecs.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
return doc_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_queries(model, processor, queries: list[str]) -> list[Any]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_vecs: list[Any] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
if model.device.type == "cuda":
|
||||||
|
with torch.autocast(
|
||||||
|
device_type="cuda",
|
||||||
|
dtype=model.dtype if model.dtype.is_floating_point else torch.bfloat16,
|
||||||
|
):
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
else:
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
q_vecs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
return q_vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _build_index(index_path: str, doc_vecs: list[Any], filepaths: list[str]) -> LeannMultiVector:
|
||||||
|
dim = int(doc_vecs[0].shape[-1])
|
||||||
|
retriever = LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
retriever.create_collection()
|
||||||
|
for i, vec in enumerate(doc_vecs):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": vec.float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
|
||||||
|
def _load_retriever_if_index_exists(index_path: str, dim: int) -> Optional[LeannMultiVector]:
|
||||||
|
index_base = Path(index_path)
|
||||||
|
# Rough heuristic: index dir exists AND meta+labels files exist
|
||||||
|
meta = index_base.parent / f"{index_base.name}.meta.json"
|
||||||
|
labels = index_base.parent / f"{index_base.name}.labels.json"
|
||||||
|
if index_base.exists() and meta.exists() and labels.exists():
|
||||||
|
return LeannMultiVector(index_path=index_path, dim=dim)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_similarity_map(
|
||||||
|
model,
|
||||||
|
processor,
|
||||||
|
image: Image.Image,
|
||||||
|
query: str,
|
||||||
|
token_idx: Optional[int] = None,
|
||||||
|
output_path: Optional[str] = None,
|
||||||
|
) -> tuple[int, float]:
|
||||||
|
import torch
|
||||||
|
from colpali_engine.interpretability import (
|
||||||
|
get_similarity_maps_from_embeddings,
|
||||||
|
plot_similarity_map,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_images = processor.process_images([image]).to(model.device)
|
||||||
|
batch_queries = processor.process_queries([query]).to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
image_embeddings = model.forward(**batch_images)
|
||||||
|
query_embeddings = model.forward(**batch_queries)
|
||||||
|
|
||||||
|
n_patches = processor.get_n_patches(
|
||||||
|
image_size=image.size,
|
||||||
|
spatial_merge_size=getattr(model, "spatial_merge_size", None),
|
||||||
|
)
|
||||||
|
image_mask = processor.get_image_mask(batch_images)
|
||||||
|
|
||||||
|
batched_similarity_maps = get_similarity_maps_from_embeddings(
|
||||||
|
image_embeddings=image_embeddings,
|
||||||
|
query_embeddings=query_embeddings,
|
||||||
|
n_patches=n_patches,
|
||||||
|
image_mask=image_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
similarity_maps = batched_similarity_maps[0]
|
||||||
|
|
||||||
|
# Determine token index if not provided: choose the token with highest max score
|
||||||
|
if token_idx is None:
|
||||||
|
per_token_max = similarity_maps.view(similarity_maps.shape[0], -1).max(dim=1).values
|
||||||
|
token_idx = int(per_token_max.argmax().item())
|
||||||
|
|
||||||
|
max_sim_score = similarity_maps[token_idx, :, :].max().item()
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
fig, ax = plot_similarity_map(
|
||||||
|
image=image,
|
||||||
|
similarity_map=similarity_maps[token_idx],
|
||||||
|
figsize=(14, 14),
|
||||||
|
show_colorbar=False,
|
||||||
|
)
|
||||||
|
ax.set_title(f"Token #{token_idx}. MaxSim score: {max_sim_score:.2f}", fontsize=12)
|
||||||
|
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||||
|
plt.savefig(output_path, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return token_idx, float(max_sim_score)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenVL:
|
||||||
|
def __init__(self, device: str):
|
||||||
|
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers.utils.import_utils import is_flash_attn_2_available
|
||||||
|
|
||||||
|
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map=device,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
|
||||||
|
min_pixels = 256 * 28 * 28
|
||||||
|
max_pixels = 1280 * 28 * 28
|
||||||
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
|
||||||
|
def answer(self, query: str, images: list[Image.Image], max_new_tokens: int = 128) -> str:
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
|
||||||
|
content = []
|
||||||
|
for img in images:
|
||||||
|
buffer = BytesIO()
|
||||||
|
img.save(buffer, format="jpeg")
|
||||||
|
img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||||
|
content.append({"type": "image", "image": f"data:image;base64,{img_base64}"})
|
||||||
|
content.append({"type": "text", "text": query})
|
||||||
|
messages = [{"role": "user", "content": content}]
|
||||||
|
|
||||||
|
text = self.processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = self.processor(
|
||||||
|
text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
|
inputs = inputs.to(self.model.device)
|
||||||
|
|
||||||
|
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
return self.processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Step 1: Prepare data
|
||||||
|
if USE_HF_DATASET:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
dataset = load_dataset(DATASET_NAME, split=DATASET_SPLIT)
|
||||||
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
|
filepaths: list[str] = []
|
||||||
|
images: list[Image.Image] = []
|
||||||
|
for i in tqdm(range(N), desc="Loading dataset"):
|
||||||
|
p = dataset[i]
|
||||||
|
# Compose a descriptive identifier for printing later
|
||||||
|
identifier = f"arXiv:{p['paper_arxiv_id']}|title:{p['paper_title']}|page:{int(p['page_number'])}|id:{p['page_id']}"
|
||||||
|
print(identifier)
|
||||||
|
filepaths.append(identifier)
|
||||||
|
images.append(p["page_image"]) # PIL Image
|
||||||
|
else:
|
||||||
|
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||||
|
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 2: Load model and processor
|
||||||
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
|
print(f"Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 3: Build or load index
|
||||||
|
retriever: Optional[LeannMultiVector] = None
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
try:
|
||||||
|
one_vec = _embed_images(model, processor, [images[0]])[0]
|
||||||
|
retriever = _load_retriever_if_index_exists(INDEX_PATH, dim=int(one_vec.shape[-1]))
|
||||||
|
except Exception:
|
||||||
|
retriever = None
|
||||||
|
|
||||||
|
if retriever is None:
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 4: Embed query and search
|
||||||
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
|
results = retriever.search(q_vec.float().numpy(), topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
if not results:
|
||||||
|
print("No results found.")
|
||||||
|
else:
|
||||||
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
|
top_images: list[Image.Image] = []
|
||||||
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
|
path = filepaths[doc_id]
|
||||||
|
# For HF dataset, path is a descriptive identifier, not a real file path
|
||||||
|
print(f"{rank}) MaxSim: {score:.4f}, Page: {path}")
|
||||||
|
top_images.append(images[doc_id])
|
||||||
|
|
||||||
|
if SAVE_TOP_IMAGE:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
base = _Path(SAVE_TOP_IMAGE)
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if base.suffix:
|
||||||
|
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||||
|
else:
|
||||||
|
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
|
img.save(str(out_path))
|
||||||
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
|
## TODO: strange result: the second page of DeepSeek-V2 is retrieved rather than the first page
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 5: Similarity maps for top-K results
|
||||||
|
if results and SIMILARITY_MAP:
|
||||||
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if output_base:
|
||||||
|
if output_base.suffix:
|
||||||
|
out_dir = output_base.parent
|
||||||
|
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||||
|
out_path = str(out_dir / out_name)
|
||||||
|
else:
|
||||||
|
out_dir = output_base
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||||
|
else:
|
||||||
|
out_path = None
|
||||||
|
chosen_idx, max_sim = _generate_similarity_map(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
image=img,
|
||||||
|
query=QUERY,
|
||||||
|
token_idx=token_idx,
|
||||||
|
output_path=out_path,
|
||||||
|
)
|
||||||
|
if out_path:
|
||||||
|
print(
|
||||||
|
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 6: Optional answer generation
|
||||||
|
if results and ANSWER:
|
||||||
|
qwen = QwenVL(device=device_str)
|
||||||
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
print("\nAnswer:")
|
||||||
|
print(response)
|
||||||
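The MaxSim scores printed in Step 4 come from the multi-vector retriever. As a rough sanity check, the late-interaction score between a query and a page can be recomputed directly from the ColVision embeddings. This is a minimal, hedged sketch; it assumes `q_vec`, `doc_vecs`, and `results` from the script above are still in memory (so it only applies when the index was just rebuilt), and it is not the retriever's internal implementation.

```python
import torch


# Hedged sketch: ColBERT-style MaxSim between one query and one page.
# query_tokens: [num_query_tokens, dim], doc_tokens: [num_patches, dim],
# as produced by _embed_queries / _embed_images above.
def maxsim(query_tokens: torch.Tensor, doc_tokens: torch.Tensor) -> float:
    # similarity of every query token against every document patch
    sim = query_tokens.float() @ doc_tokens.float().T  # [num_query_tokens, num_patches]
    # late interaction: best-matching patch per query token, summed over tokens
    return sim.max(dim=-1).values.sum().item()


# Illustrative usage, only valid if doc_vecs was computed in this run:
# score_manual = maxsim(q_vec, doc_vecs[results[0][1]])
# print(f"Manual MaxSim for top hit: {score_manual:.4f}")
```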
@@ -0,0 +1,134 @@
|
|||||||
|
# pip install pdf2image
|
||||||
|
# pip install pymilvus
|
||||||
|
# pip install colpali_engine
|
||||||
|
# pip install tqdm
|
||||||
|
# pip install pillow
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
from pdf2image import convert_from_path
|
||||||
|
|
||||||
|
pdf_path = "pdfs/2004.12832v2.pdf"
|
||||||
|
images = convert_from_path(pdf_path)
|
||||||
|
|
||||||
|
os.makedirs("pages", exist_ok=True)  # make sure the output directory exists
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
image.save(f"pages/page_{i + 1}.png", "PNG")
|
||||||
|
|
||||||
|
# %%
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Make local leann packages importable without installing
|
||||||
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
import sys
|
||||||
|
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
from leann_multi_vector import LeannMultiVector
|
||||||
|
|
||||||
|
|
||||||
|
class LeannRetriever(LeannMultiVector):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||||
|
_device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(_device_str)
|
||||||
|
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||||
|
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=_dtype,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"How to end-to-end retrieval with ColBert",
|
||||||
|
"Where is ColBERT performance Table, including text representation results?",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
qs: list[torch.Tensor] = []
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
print(qs[0].shape)
|
||||||
|
# %%
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
ds: list[torch.Tensor] = []
|
||||||
|
for batch_doc in tqdm(dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
|
||||||
|
print(ds[0].shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Build HNSW index via LeannRetriever primitives and run search
|
||||||
|
index_path = "./indexes/colpali.leann"
|
||||||
|
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||||
|
retriever.create_collection()
|
||||||
|
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": ds[i].float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
for query in qs:
|
||||||
|
query_np = query.float().numpy()
|
||||||
|
result = retriever.search(query_np, topk=1)
|
||||||
|
print(filepaths[result[0][1]])
|
||||||
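For a brute-force cross-check of the index-based search above, the query and page embeddings can be scored directly. This is a hedged sketch that reuses `qs`, `ds`, `queries`, `filepaths`, and `processor` from the script above; `ColPaliProcessor.score_multi_vector` may not exist in older colpali_engine releases, so a manual MaxSim fallback is included.

```python
import torch

# Hedged sketch: score every page against every query without the HNSW index.
try:
    scores = processor.score_multi_vector(qs, ds)  # [num_queries, num_pages]
except AttributeError:
    # Manual late-interaction (MaxSim) fallback
    scores = torch.stack(
        [
            torch.stack([(q.float() @ d.float().T).max(dim=-1).values.sum() for d in ds])
            for q in qs
        ]
    )

for qi, query_text in enumerate(queries):
    best_page = int(scores[qi].argmax())
    print(f"{query_text!r} -> {filepaths[best_page]} (score={float(scores[qi][best_page]):.2f})")
```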
@@ -1,230 +0,0 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import dotenv
|
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any, Optional
|
|
||||||
from leann.api import LeannBuilder, LeannSearcher, LeannChat
|
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
|
||||||
|
|
||||||
# Default WeChat export directory
|
|
||||||
DEFAULT_WECHAT_EXPORT_DIR = "./wechat_export_direct"
|
|
||||||
|
|
||||||
def create_leann_index_from_multiple_wechat_exports(
|
|
||||||
export_dirs: List[Path],
|
|
||||||
index_path: str = "wechat_history_index.leann",
|
|
||||||
max_count: int = -1,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create LEANN index from multiple WeChat export data sources.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_dirs: List of Path objects pointing to WeChat export directories
|
|
||||||
index_path: Path to save the LEANN index
|
|
||||||
max_count: Maximum number of chat entries to process per export
|
|
||||||
"""
|
|
||||||
print("Creating LEANN index from multiple WeChat export data sources...")
|
|
||||||
|
|
||||||
# Load documents using WeChatHistoryReader from local readers module
|
|
||||||
from .readers import WeChatHistoryReader
|
|
||||||
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(index_path).parent
|
|
||||||
|
|
||||||
if not INDEX_DIR.exists():
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
all_documents = []
|
|
||||||
total_processed = 0
|
|
||||||
|
|
||||||
# Process each WeChat export directory
|
|
||||||
for i, export_dir in enumerate(export_dirs):
|
|
||||||
print(
|
|
||||||
f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
documents = reader.load_data(
|
|
||||||
wechat_export_dir=str(export_dir),
|
|
||||||
max_count=max_count,
|
|
||||||
concatenate_messages=True,  # Concatenate messages into grouped documents for better context
|
|
||||||
)
|
|
||||||
if documents:
|
|
||||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
|
||||||
all_documents.extend(documents)
|
|
||||||
total_processed += len(documents)
|
|
||||||
|
|
||||||
# Check if we've reached the max count
|
|
||||||
if max_count > 0 and total_processed >= max_count:
|
|
||||||
print(f"Reached max count of {max_count} documents")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
print(f"No documents loaded from {export_dir}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {export_dir}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not all_documents:
|
|
||||||
print("No documents loaded from any source. Exiting.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create text splitter with 256 chunk size
|
|
||||||
text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)
|
|
||||||
|
|
||||||
# Convert Documents to text strings and chunk them
|
|
||||||
all_texts = []
|
|
||||||
for doc in all_documents:
|
|
||||||
# Split the document into chunks
|
|
||||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
|
||||||
for node in nodes:
|
|
||||||
text = '[Contact] means the message is from: ' + doc.metadata["contact_name"] + '\n' + node.get_content()
|
|
||||||
all_texts.append(text)
|
|
||||||
|
|
||||||
print(
|
|
||||||
f"Created {len(all_texts)} text chunks from {len(all_documents)} documents"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create LEANN index directory
|
|
||||||
print(f"--- Index directory not found, building new index ---")
|
|
||||||
INDEX_DIR.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"--- Building new LEANN index ---")
|
|
||||||
|
|
||||||
print(f"\n[PHASE 1] Building Leann index...")
|
|
||||||
|
|
||||||
# Use HNSW backend for better macOS compatibility
|
|
||||||
builder = LeannBuilder(
|
|
||||||
backend_name="hnsw",
|
|
||||||
embedding_model="Qwen/Qwen3-Embedding-0.6B",
|
|
||||||
graph_degree=32,
|
|
||||||
complexity=64,
|
|
||||||
is_compact=True,
|
|
||||||
is_recompute=True,
|
|
||||||
num_threads=1, # Force single-threaded mode
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Adding {len(all_texts)} chat chunks to index...")
|
|
||||||
for chunk_text in all_texts:
|
|
||||||
builder.add_text(chunk_text)
|
|
||||||
|
|
||||||
builder.build_index(index_path)
|
|
||||||
print(f"\nLEANN index built at {index_path}!")
|
|
||||||
else:
|
|
||||||
print(f"--- Using existing index at {INDEX_DIR} ---")
|
|
||||||
|
|
||||||
return index_path
|
|
||||||
|
|
||||||
async def query_leann_index(index_path: str, query: str):
|
|
||||||
"""
|
|
||||||
Query the LEANN index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_path: Path to the LEANN index
|
|
||||||
query: The query string
|
|
||||||
"""
|
|
||||||
print(f"\n[PHASE 2] Starting Leann chat session...")
|
|
||||||
chat = LeannChat(index_path=index_path)
|
|
||||||
|
|
||||||
print(f"You: {query}")
|
|
||||||
chat_response = chat.ask(
|
|
||||||
query,
|
|
||||||
top_k=20,
|
|
||||||
recompute_beighbor_embeddings=True,
|
|
||||||
complexity=16,
|
|
||||||
beam_width=1,
|
|
||||||
llm_config={
|
|
||||||
"type": "openai",
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
llm_kwargs={"temperature": 0.0, "max_tokens": 1000},
|
|
||||||
)
|
|
||||||
print(f"Leann: {chat_response}")
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""Main function with integrated WeChat export functionality."""
|
|
||||||
|
|
||||||
# Parse command line arguments
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="LEANN WeChat History Reader - Create and query WeChat chat history index"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--export-dir",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_WECHAT_EXPORT_DIR,
|
|
||||||
help=f"Directory to store WeChat exports (default: {DEFAULT_WECHAT_EXPORT_DIR})",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--index-dir",
|
|
||||||
type=str,
|
|
||||||
default="./wechat_history_magic_test_11Debug_new",
|
|
||||||
help="Directory to store the LEANN index (default: ./wechat_history_index_leann_test)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-entries",
|
|
||||||
type=int,
|
|
||||||
default=50,
|
|
||||||
help="Maximum number of chat entries to process (default: 5000)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--query",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Single query to run (default: runs example queries)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--force-export",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Force re-export of WeChat data even if exports exist",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
INDEX_DIR = Path(args.index_dir)
|
|
||||||
INDEX_PATH = str(INDEX_DIR / "wechat_history.leann")
|
|
||||||
|
|
||||||
print(f"Using WeChat export directory: {args.export_dir}")
|
|
||||||
print(f"Index directory: {INDEX_DIR}")
|
|
||||||
print(f"Max entries: {args.max_entries}")
|
|
||||||
|
|
||||||
# Initialize WeChat reader with export capabilities
|
|
||||||
from .readers import WeChatHistoryReader
|
|
||||||
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
|
|
||||||
# Find existing exports or create new ones using the centralized method
|
|
||||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
|
||||||
if not export_dirs:
|
|
||||||
print("Failed to find or export WeChat data. Exiting.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create or load the LEANN index from all sources
|
|
||||||
index_path = create_leann_index_from_multiple_wechat_exports(
|
|
||||||
export_dirs, INDEX_PATH, max_count=args.max_entries
|
|
||||||
)
|
|
||||||
|
|
||||||
if index_path:
|
|
||||||
if args.query:
|
|
||||||
# Run single query
|
|
||||||
await query_leann_index(index_path, args.query)
|
|
||||||
else:
|
|
||||||
# Example queries
|
|
||||||
queries = [
|
|
||||||
"我想买魔术师约翰逊的球衣,给我一些对应聊天记录?",
|
|
||||||
]
|
|
||||||
|
|
||||||
for query in queries:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
await query_leann_index(index_path, query)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
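The two index paths above (build if missing, then chat) go through the `LeannBuilder`/`LeannChat` wrappers. A minimal, hedged sketch of the same build-and-search flow using only the LEANN calls that appear elsewhere in this change; the index path, text, and parameter values below are illustrative placeholders, not the app's configuration.

```python
from pathlib import Path

from leann.api import LeannBuilder, LeannSearcher

# Hedged sketch: build a tiny index and query it directly, without the chat wrapper.
Path("./demo_index").mkdir(parents=True, exist_ok=True)

builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="Qwen/Qwen3-Embedding-0.6B",
    graph_degree=32,
    complexity=64,
    is_compact=True,
    is_recompute=True,
)
builder.add_text("[Contact] means the message is from: Alice\n(2024-01-01 10:00:00) [Me]: hello")
builder.build_index("./demo_index/wechat_demo.leann")

searcher = LeannSearcher(index_path="./demo_index/wechat_demo.leann")
results = searcher.search("hello", top_k=1, complexity=16)
print(results)
```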
@@ -1,719 +0,0 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Any, Dict, Optional
|
|
||||||
from llama_index.core import Document
|
|
||||||
from llama_index.core.readers.base import BaseReader
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
class WeChatHistoryReader(BaseReader):
|
|
||||||
"""
|
|
||||||
WeChat chat history reader that extracts chat data from exported JSON files.
|
|
||||||
|
|
||||||
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
|
||||||
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
|
||||||
|
|
||||||
Also includes utilities for automatic WeChat chat history export.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
"""Initialize."""
|
|
||||||
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
|
||||||
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
|
||||||
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
|
||||||
|
|
||||||
def check_wechat_running(self) -> bool:
|
|
||||||
"""Check if WeChat is currently running."""
|
|
||||||
try:
|
|
||||||
result = subprocess.run(["pgrep", "-f", "WeChat"], capture_output=True, text=True)
|
|
||||||
return result.returncode == 0
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def install_wechattweak(self) -> bool:
|
|
||||||
"""Install WeChatTweak CLI tool."""
|
|
||||||
try:
|
|
||||||
# Create wechat-exporter directory if it doesn't exist
|
|
||||||
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
|
||||||
if not wechattweak_path.exists():
|
|
||||||
print("Downloading WeChatTweak CLI...")
|
|
||||||
subprocess.run([
|
|
||||||
"curl", "-L", "-o", str(wechattweak_path),
|
|
||||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Make executable
|
|
||||||
wechattweak_path.chmod(0o755)
|
|
||||||
|
|
||||||
# Install WeChatTweak
|
|
||||||
print("Installing WeChatTweak...")
|
|
||||||
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error installing WeChatTweak: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def restart_wechat(self):
|
|
||||||
"""Restart WeChat to apply WeChatTweak."""
|
|
||||||
try:
|
|
||||||
print("Restarting WeChat...")
|
|
||||||
subprocess.run(["pkill", "-f", "WeChat"], check=False)
|
|
||||||
time.sleep(2)
|
|
||||||
subprocess.run(["open", "-a", "WeChat"], check=True)
|
|
||||||
time.sleep(5) # Wait for WeChat to start
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error restarting WeChat: {e}")
|
|
||||||
|
|
||||||
def check_api_available(self) -> bool:
|
|
||||||
"""Check if WeChatTweak API is available."""
|
|
||||||
try:
|
|
||||||
result = subprocess.run([
|
|
||||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
|
||||||
], capture_output=True, text=True, timeout=5)
|
|
||||||
return result.returncode == 0 and bool(result.stdout.strip())
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_readable_text(self, content: str) -> str:
|
|
||||||
"""
|
|
||||||
Extract readable text from message content, removing XML and system messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The raw message content (can be string or dict)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Cleaned, readable text
|
|
||||||
"""
|
|
||||||
if not content:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Handle dictionary content (like quoted messages)
|
|
||||||
if isinstance(content, dict):
|
|
||||||
# Extract text from dictionary structure
|
|
||||||
text_parts = []
|
|
||||||
if 'title' in content:
|
|
||||||
text_parts.append(str(content['title']))
|
|
||||||
if 'quoted' in content:
|
|
||||||
text_parts.append(str(content['quoted']))
|
|
||||||
if 'content' in content:
|
|
||||||
text_parts.append(str(content['content']))
|
|
||||||
if 'text' in content:
|
|
||||||
text_parts.append(str(content['text']))
|
|
||||||
|
|
||||||
if text_parts:
|
|
||||||
return " | ".join(text_parts)
|
|
||||||
else:
|
|
||||||
# If we can't extract meaningful text from dict, return empty
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Handle string content
|
|
||||||
if not isinstance(content, str):
|
|
||||||
return ""
|
|
||||||
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n"
|
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
|
||||||
|
|
||||||
# If it's just XML or system message, return empty
|
|
||||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return clean_content.strip()
|
|
||||||
|
|
||||||
def _is_text_message(self, content: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a message contains readable text content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: The message content (can be string or dict)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the message contains readable text, False otherwise
|
|
||||||
"""
|
|
||||||
if not content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Handle dictionary content
|
|
||||||
if isinstance(content, dict):
|
|
||||||
# Check if dict has any readable text fields
|
|
||||||
text_fields = ['title', 'quoted', 'content', 'text']
|
|
||||||
for field in text_fields:
|
|
||||||
if field in content and content[field]:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Handle string content
|
|
||||||
if not isinstance(content, str):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip image messages (contain XML with img tags)
|
|
||||||
if '<img' in content and 'cdnurl' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip emoji messages (contain emoji XML tags)
|
|
||||||
if '<emoji' in content and 'productid' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip voice messages
|
|
||||||
if '<voice' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip video messages
|
|
||||||
if '<video' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip file messages
|
|
||||||
if '<appmsg' in content and 'appid' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Skip system messages (like "recalled a message")
|
|
||||||
if 'recalled a message' in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if there's actual readable text (not just XML or system messages)
|
|
||||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
|
||||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
|
||||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
|
||||||
|
|
||||||
# If after cleaning we have meaningful text, consider it readable
|
|
||||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
|
||||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Concatenate messages based on length and time rules.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of message dictionaries
|
|
||||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
|
||||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
|
||||||
overlap_messages: Number of messages to overlap between consecutive groups
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of concatenated message groups
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return []
|
|
||||||
|
|
||||||
concatenated_groups = []
|
|
||||||
current_group = []
|
|
||||||
current_length = 0
|
|
||||||
last_timestamp = None
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
# Extract message info
|
|
||||||
content = message.get('content', '')
|
|
||||||
message_text = message.get('message', '')
|
|
||||||
create_time = message.get('createTime', 0)
|
|
||||||
from_user = message.get('fromUser', '')
|
|
||||||
to_user = message.get('toUser', '')
|
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
|
||||||
|
|
||||||
# Extract readable text
|
|
||||||
readable_text = self._extract_readable_text(content)
|
|
||||||
if not readable_text:
|
|
||||||
readable_text = message_text
|
|
||||||
|
|
||||||
# Skip empty messages
|
|
||||||
if not readable_text.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check time window constraint (only if time_window_minutes != -1)
|
|
||||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
|
||||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
|
||||||
if time_diff_minutes > time_window_minutes:
|
|
||||||
# Time gap too large, start new group
|
|
||||||
if current_group:
|
|
||||||
concatenated_groups.append({
|
|
||||||
'messages': current_group,
|
|
||||||
'total_length': current_length,
|
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
|
||||||
})
|
|
||||||
# Keep last few messages for overlap
|
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
|
||||||
current_group = current_group[-overlap_messages:]
|
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
|
||||||
else:
|
|
||||||
current_group = []
|
|
||||||
current_length = 0
|
|
||||||
|
|
||||||
# Check length constraint (only if max_length != -1)
|
|
||||||
message_length = len(readable_text)
|
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
|
||||||
# Current group would exceed max length, save it and start new
|
|
||||||
concatenated_groups.append({
|
|
||||||
'messages': current_group,
|
|
||||||
'total_length': current_length,
|
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
|
||||||
})
|
|
||||||
# Keep last few messages for overlap
|
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
|
||||||
current_group = current_group[-overlap_messages:]
|
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
|
||||||
else:
|
|
||||||
current_group = []
|
|
||||||
current_length = 0
|
|
||||||
|
|
||||||
# Add message to current group
|
|
||||||
current_group.append(message)
|
|
||||||
current_length += message_length
|
|
||||||
last_timestamp = create_time
|
|
||||||
|
|
||||||
# Add the last group if it exists
|
|
||||||
if current_group:
|
|
||||||
concatenated_groups.append({
|
|
||||||
'messages': current_group,
|
|
||||||
'total_length': current_length,
|
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
|
||||||
})
|
|
||||||
|
|
||||||
return concatenated_groups
|
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> tuple:
|
|
||||||
"""
|
|
||||||
Create concatenated content from a group of messages.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_group: Dictionary containing messages and metadata
|
|
||||||
contact_name: Name of the contact
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (formatted concatenated content, contact name)
|
|
||||||
"""
|
|
||||||
messages = message_group['messages']
|
|
||||||
start_time = message_group['start_time']
|
|
||||||
end_time = message_group['end_time']
|
|
||||||
|
|
||||||
# Format timestamps
|
|
||||||
if start_time:
|
|
||||||
try:
|
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
|
||||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
except:
|
|
||||||
start_time_str = str(start_time)
|
|
||||||
else:
|
|
||||||
start_time_str = "Unknown"
|
|
||||||
|
|
||||||
if end_time:
|
|
||||||
try:
|
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
|
||||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
except:
|
|
||||||
end_time_str = str(end_time)
|
|
||||||
else:
|
|
||||||
end_time_str = "Unknown"
|
|
||||||
|
|
||||||
# Build concatenated message content
|
|
||||||
message_parts = []
|
|
||||||
for message in messages:
|
|
||||||
content = message.get('content', '')
|
|
||||||
message_text = message.get('message', '')
|
|
||||||
create_time = message.get('createTime', 0)
|
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
|
||||||
|
|
||||||
# Extract readable text
|
|
||||||
readable_text = self._extract_readable_text(content)
|
|
||||||
if not readable_text:
|
|
||||||
readable_text = message_text
|
|
||||||
|
|
||||||
# Format individual message
|
|
||||||
if create_time:
|
|
||||||
try:
|
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
except:
|
|
||||||
time_str = str(create_time)
|
|
||||||
else:
|
|
||||||
time_str = "Unknown"
|
|
||||||
|
|
||||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
|
||||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
|
||||||
|
|
||||||
concatenated_text = "\n".join(message_parts)
|
|
||||||
|
|
||||||
# Create final document content
|
|
||||||
doc_content = f"""
|
|
||||||
Contact: {contact_name}
|
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
|
||||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|
||||||
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
# TODO @yichuan give better format and rich info here!
|
|
||||||
doc_content = f"""
|
|
||||||
{concatenated_text}
|
|
||||||
"""
|
|
||||||
return doc_content, contact_name
|
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
|
||||||
"""
|
|
||||||
Load WeChat chat history data from exported JSON files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_dir: Directory containing exported WeChat JSON files
|
|
||||||
**load_kwargs:
|
|
||||||
max_count (int): Maximum amount of chat entries to read.
|
|
||||||
wechat_export_dir (str): Custom path to WeChat export directory.
|
|
||||||
include_non_text (bool): Whether to include non-text messages (images, emojis, etc.)
|
|
||||||
concatenate_messages (bool): Whether to concatenate messages based on length rules.
|
|
||||||
max_length (int): Maximum length for concatenated message groups (default: 1000).
|
|
||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
|
||||||
"""
|
|
||||||
docs: List[Document] = []
|
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
|
||||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
|
||||||
include_non_text = load_kwargs.get('include_non_text', False)
|
|
||||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
|
||||||
max_length = load_kwargs.get('max_length', 1000)
|
|
||||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
|
||||||
|
|
||||||
# Default WeChat export path
|
|
||||||
if wechat_export_dir is None:
|
|
||||||
wechat_export_dir = "./wechat_export_test"
|
|
||||||
|
|
||||||
if not os.path.exists(wechat_export_dir):
|
|
||||||
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Find all JSON files in the export directory
|
|
||||||
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
|
||||||
print(f"Found {len(json_files)} WeChat chat history files")
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for json_file in json_files:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
|
||||||
chat_data = json.load(f)
|
|
||||||
|
|
||||||
# Extract contact name from filename
|
|
||||||
contact_name = json_file.stem
|
|
||||||
|
|
||||||
if concatenate_messages:
|
|
||||||
# Filter messages to only include readable text messages
|
|
||||||
readable_messages = []
|
|
||||||
for message in chat_data:
|
|
||||||
try:
|
|
||||||
content = message.get('content', '')
|
|
||||||
if not include_non_text and not self._is_text_message(content):
|
|
||||||
continue
|
|
||||||
|
|
||||||
readable_text = self._extract_readable_text(content)
|
|
||||||
if not readable_text and not include_non_text:
|
|
||||||
continue
|
|
||||||
|
|
||||||
readable_messages.append(message)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing message in {json_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Concatenate messages based on rules
|
|
||||||
message_groups = self._concatenate_messages(
|
|
||||||
readable_messages,
|
|
||||||
max_length=-1,
|
|
||||||
time_window_minutes=-1,
|
|
||||||
overlap_messages=0  # No overlap between consecutive groups
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
|
||||||
for message_group in message_groups:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
|
||||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Original single-message processing
|
|
||||||
for message in chat_data:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Extract message information
|
|
||||||
from_user = message.get('fromUser', '')
|
|
||||||
to_user = message.get('toUser', '')
|
|
||||||
content = message.get('content', '')
|
|
||||||
message_text = message.get('message', '')
|
|
||||||
create_time = message.get('createTime', 0)
|
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
|
||||||
try:
|
|
||||||
# Check if this is a readable text message
|
|
||||||
if not include_non_text and not self._is_text_message(content):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Extract readable text
|
|
||||||
readable_text = self._extract_readable_text(content)
|
|
||||||
if not readable_text and not include_non_text:
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
|
||||||
# Skip messages that cause processing errors
|
|
||||||
print(f"Error processing message in {json_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Convert timestamp to readable format
|
|
||||||
if create_time:
|
|
||||||
try:
|
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
except:
|
|
||||||
time_str = str(create_time)
|
|
||||||
else:
|
|
||||||
time_str = "Unknown"
|
|
||||||
|
|
||||||
# Create document content with metadata header and contact info
|
|
||||||
doc_content = f"""
|
|
||||||
Contact: {contact_name}
|
|
||||||
Is sent from self: {is_sent_from_self}
|
|
||||||
Time: {time_str}
|
|
||||||
Message: {readable_text if readable_text else message_text}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Create document with embedded metadata
|
|
||||||
doc = Document(text=doc_content, metadata={})
|
|
||||||
docs.append(doc)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading {json_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Loaded {len(docs)} WeChat chat documents")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading WeChat history: {e}")
|
|
||||||
return docs
|
|
||||||
|
|
||||||
return docs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def find_wechat_export_dirs() -> List[Path]:
|
|
||||||
"""
|
|
||||||
Find all WeChat export directories.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Path objects pointing to WeChat export directories
|
|
||||||
"""
|
|
||||||
export_dirs = []
|
|
||||||
|
|
||||||
# Look for common export directory names
|
|
||||||
possible_dirs = [
|
|
||||||
Path("./wechat_export_test"),
|
|
||||||
Path("./wechat_export"),
|
|
||||||
Path("./wechat_chat_history"),
|
|
||||||
Path("./chat_export")
|
|
||||||
]
|
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
|
||||||
if export_dir.exists() and export_dir.is_dir():
|
|
||||||
json_files = list(export_dir.glob("*.json"))
|
|
||||||
if json_files:
|
|
||||||
export_dirs.append(export_dir)
|
|
||||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
|
||||||
return export_dirs
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
|
||||||
"""
|
|
||||||
Export WeChat chat history to a text file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
output_file: Path to the output file
|
|
||||||
max_count: Maximum number of entries to export
|
|
||||||
export_dir: Directory containing WeChat JSON files
|
|
||||||
include_non_text: Whether to include non-text messages
|
|
||||||
"""
|
|
||||||
if export_dir is None:
|
|
||||||
export_dir = "./wechat_export_test"
|
|
||||||
|
|
||||||
if not os.path.exists(export_dir):
|
|
||||||
print(f"WeChat export directory not found at: {export_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
|
||||||
count = 0
|
|
||||||
for json_file in json_files:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
|
||||||
chat_data = json.load(json_f)
|
|
||||||
|
|
||||||
contact_name = json_file.stem
|
|
||||||
f.write(f"\n=== Chat with {contact_name} ===\n")
|
|
||||||
|
|
||||||
for message in chat_data:
|
|
||||||
if count >= max_count and max_count > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
from_user = message.get('fromUser', '')
|
|
||||||
content = message.get('content', '')
|
|
||||||
message_text = message.get('message', '')
|
|
||||||
create_time = message.get('createTime', 0)
|
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
|
||||||
if not include_non_text:
|
|
||||||
reader = WeChatHistoryReader()
|
|
||||||
if not reader._is_text_message(content):
|
|
||||||
continue
|
|
||||||
readable_text = reader._extract_readable_text(content)
|
|
||||||
if not readable_text:
|
|
||||||
continue
|
|
||||||
message_text = readable_text
|
|
||||||
|
|
||||||
if create_time:
|
|
||||||
try:
|
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
except:
|
|
||||||
time_str = str(create_time)
|
|
||||||
else:
|
|
||||||
time_str = "Unknown"
|
|
||||||
|
|
||||||
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error processing {json_file}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"Exported {count} chat entries to {output_file}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
|
||||||
"""
|
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_dir: Directory to save exported chat history
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to export directory if successful, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Create export directory
|
|
||||||
export_path = Path(export_dir)
|
|
||||||
export_path.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
print(f"Exporting WeChat chat history to {export_path}...")
|
|
||||||
|
|
||||||
# Check if wechat-exporter directory exists
|
|
||||||
if not self.wechat_exporter_dir.exists():
|
|
||||||
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Install requirements if needed
|
|
||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
|
||||||
if requirements_file.exists():
|
|
||||||
print("Installing wechat-exporter requirements...")
|
|
||||||
subprocess.run([
|
|
||||||
"uv", "pip", "install", "-r", str(requirements_file)
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Run the export command
|
|
||||||
print("Running wechat-exporter...")
|
|
||||||
result = subprocess.run([
|
|
||||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
|
||||||
"export-all", str(export_path)
|
|
||||||
], capture_output=True, text=True, check=True)
|
|
||||||
|
|
||||||
print("Export command output:")
|
|
||||||
print(result.stdout)
|
|
||||||
if result.stderr:
|
|
||||||
print("Export errors:")
|
|
||||||
print(result.stderr)
|
|
||||||
|
|
||||||
# Check if export was successful
|
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
|
||||||
json_files = list(export_path.glob("*.json"))
|
|
||||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
|
||||||
return export_path
|
|
||||||
else:
|
|
||||||
print("Export completed but no JSON files found")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"Export command failed: {e}")
|
|
||||||
print(f"Command output: {e.stdout}")
|
|
||||||
print(f"Command errors: {e.stderr}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Export failed: {e}")
|
|
||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
|
||||||
"""
|
|
||||||
Find existing WeChat exports or create new ones.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
export_dir: Directory to save exported chat history if needed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Path objects pointing to WeChat export directories
|
|
||||||
"""
|
|
||||||
export_dirs = []
|
|
||||||
|
|
||||||
# Look for existing exports in common locations
|
|
||||||
possible_export_dirs = [
|
|
||||||
Path("./wechat_database_export"),
|
|
||||||
Path("./wechat_export_test"),
|
|
||||||
Path("./wechat_export"),
|
|
||||||
Path("./wechat_export_direct"),
|
|
||||||
Path("./wechat_chat_history"),
|
|
||||||
Path("./chat_export")
|
|
||||||
]
|
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
|
||||||
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
|
||||||
export_dirs.append(export_dir_path)
|
|
||||||
print(f"Found existing export: {export_dir_path}")
|
|
||||||
|
|
||||||
# If no existing exports, try to export automatically
|
|
||||||
if not export_dirs:
|
|
||||||
print("No existing WeChat exports found. Starting direct export...")
|
|
||||||
|
|
||||||
# Try to export using wechat-exporter
|
|
||||||
exported_path = self.export_wechat_chat_history(export_dir)
|
|
||||||
if exported_path:
|
|
||||||
export_dirs = [exported_path]
|
|
||||||
else:
|
|
||||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
|
||||||
|
|
||||||
return export_dirs
|
|
||||||
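The grouping rules in `_concatenate_messages` (length cap, time window, optional overlap) are easiest to see on toy data. A hedged sketch, assuming the JSON message schema used above (`content`, `createTime`, `isSentFromSelf`); it calls the reader's private helpers purely for illustration, with made-up timestamps and text.

```python
# Hedged sketch: exercise the concatenation rules on synthetic messages.
reader = WeChatHistoryReader()
toy_messages = [
    {"content": "Are we still on for dinner?", "createTime": 1_700_000_000, "isSentFromSelf": False},
    {"content": "Yes, 7pm works for me.", "createTime": 1_700_000_120, "isSentFromSelf": True},
    # About two hours later, so a 30-minute time window starts a new group
    {"content": "On my way now.", "createTime": 1_700_007_200, "isSentFromSelf": True},
]
groups = reader._concatenate_messages(toy_messages, max_length=128, time_window_minutes=30)
for group in groups:
    text, contact = reader._create_concatenated_content(group, "Alice")
    print(f"{len(group['messages'])} message(s), {group['total_length']} chars from {contact}")
    print(text)
```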
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
WeChat History RAG example using the unified interface.
|
||||||
|
Supports WeChat chat history export and search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
|
||||||
|
from .history_data.wechat_history import WeChatHistoryReader
|
||||||
|
|
||||||
|
|
||||||
|
class WeChatRAG(BaseRAGExample):
|
||||||
|
"""RAG example for WeChat chat history."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Match original default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="WeChat History",
|
||||||
|
description="Process and query WeChat chat history with LEANN",
|
||||||
|
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add WeChat-specific arguments."""
|
||||||
|
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--export-dir",
|
||||||
|
type=str,
|
||||||
|
default="./wechat_export",
|
||||||
|
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--force-export",
|
||||||
|
action="store_true",
|
||||||
|
help="Force re-export of WeChat data even if exports exist",
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||||
|
)
|
||||||
|
wechat_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||||
|
"""Export WeChat data using wechattweak-cli."""
|
||||||
|
print("Exporting WeChat data...")
|
||||||
|
|
||||||
|
# Check if WeChat is running
|
||||||
|
try:
|
||||||
|
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||||
|
if result.returncode != 0:
|
||||||
|
print("WeChat is not running. Please start WeChat first.")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass # pgrep might not be available on all systems
|
||||||
|
|
||||||
|
# Create export directory
|
||||||
|
export_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Run export command
|
||||||
|
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"Running: {' '.join(cmd)}")
|
||||||
|
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||||
|
|
||||||
|
if result.returncode == 0:
|
||||||
|
print("WeChat data exported successfully!")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"Export failed: {result.stderr}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print("\nError: wechattweak-cli not found!")
|
||||||
|
print("Please install it first:")
|
||||||
|
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Export error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load WeChat history and convert to text chunks."""
|
||||||
|
# Initialize WeChat reader with export capabilities
|
||||||
|
reader = WeChatHistoryReader()
|
||||||
|
|
||||||
|
# Find existing exports or create new ones using the centralized method
|
||||||
|
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||||
|
if not export_dirs:
|
||||||
|
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||||
|
# Try to find any existing exports in common locations
|
||||||
|
export_dirs = reader.find_wechat_export_dirs()
|
||||||
|
if not export_dirs:
|
||||||
|
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Load documents from all found export directories
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, export_dir in enumerate(export_dirs):
|
||||||
|
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per export
|
||||||
|
max_per_export = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_export = remaining
|
||||||
|
|
||||||
|
documents = reader.load_data(
|
||||||
|
wechat_export_dir=str(export_dir),
|
||||||
|
max_count=max_per_export,
|
||||||
|
concatenate_messages=True, # Enable message concatenation for better context
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
else:
|
||||||
|
print(f"No documents loaded from {export_dir}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_dir}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No documents loaded from any source. Exiting.")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||||
|
print("now starting to split into text chunks ... take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks with contact information
|
||||||
|
all_texts = []
|
||||||
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
|
text_splitter = SentenceSplitter(
|
||||||
|
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
for doc in all_documents:
|
||||||
|
# Split the document into chunks
|
||||||
|
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
# Add contact information to each chunk
|
||||||
|
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||||
|
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||||
|
all_texts.append(text)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Check platform
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||||
|
print(" You can still query existing exports on other platforms\n")
|
||||||
|
|
||||||
|
# Example queries for WeChat RAG
|
||||||
|
print("\n💬 WeChat History RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'Show me conversations about travel plans'")
|
||||||
|
print("- 'Find group chats about weekend activities'")
|
||||||
|
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||||
|
print("- 'What did we discuss about the project last month?'")
|
||||||
|
print("\nNote: WeChat must be running for export to work\n")
|
||||||
|
|
||||||
|
rag = WeChatRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
BIN
assets/claude_code_leann.png
Normal file (binary, 73 KiB)
BIN
assets/mcp_leann.png
Normal file (binary, 224 KiB)
BIN
assets/wechat_user_group.JPG
Normal file (binary, 152 KiB)
@@ -1,9 +1,24 @@
|
|||||||
# 🧪 Leann Sanity Checks
|
# 🧪 LEANN Benchmarks & Testing
|
||||||
|
|
||||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||||
|
|
||||||
## 📁 Test Files
|
## 📁 Test Files
|
||||||
|
|
||||||
|
### `diskann_vs_hnsw_speed_comparison.py`
|
||||||
|
Performance comparison between DiskANN and HNSW backends:
|
||||||
|
- ✅ **Search latency** comparison with both backends using recompute
|
||||||
|
- ✅ **Index size** and **build time** measurements
|
||||||
|
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||||
|
- ✅ **Configurable dataset sizes** for different scales
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Quick comparison with 500 docs, 10 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||||
|
|
||||||
|
# Large-scale comparison with 2000 docs, 20 queries
|
||||||
|
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||||
|
```
|
||||||
|
|
||||||
### `test_distance_functions.py`
|
### `test_distance_functions.py`
|
||||||
Tests all supported distance functions across DiskANN backend:
|
Tests all supported distance functions across DiskANN backend:
|
||||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||||
@@ -1,29 +1,32 @@
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from mlx_lm import load
|
from mlx_lm import load
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||||
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||||
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||||
NUM_RUNS = 10 # Number of runs to average for each batch size
|
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||||
|
|
||||||
# --- Generate Dummy Data ---
|
# --- Generate Dummy Data ---
|
||||||
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||||
|
|
||||||
# --- Benchmark Functions ---
|
# --- Benchmark Functions ---
|
||||||
|
|
||||||
|
|
||||||
def benchmark_torch(model, sentences):
|
def benchmark_torch(model, sentences):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
model.encode(sentences, convert_to_numpy=True)
|
model.encode(sentences, convert_to_numpy=True)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
return (end_time - start_time) * 1000 # Return time in ms
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
def benchmark_mlx(model, tokenizer, sentences):
|
def benchmark_mlx(model, tokenizer, sentences):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -63,6 +66,7 @@ def benchmark_mlx(model, tokenizer, sentences):
|
|||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
return (end_time - start_time) * 1000 # Return time in ms
|
return (end_time - start_time) * 1000 # Return time in ms
|
||||||
|
|
||||||
|
|
||||||
# --- Main Execution ---
|
# --- Main Execution ---
|
||||||
def main():
|
def main():
|
||||||
print("--- Initializing Models ---")
|
print("--- Initializing Models ---")
|
||||||
@@ -98,7 +102,9 @@ def main():
|
|||||||
results_torch.append(np.mean(torch_times))
|
results_torch.append(np.mean(torch_times))
|
||||||
|
|
||||||
# Benchmark MLX
|
# Benchmark MLX
|
||||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
mlx_times = [
|
||||||
|
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||||
|
]
|
||||||
results_mlx.append(np.mean(mlx_times))
|
results_mlx.append(np.mean(mlx_times))
|
||||||
|
|
||||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||||
@@ -109,10 +115,16 @@ def main():
|
|||||||
# --- Plotting ---
|
# --- Plotting ---
|
||||||
print("\n--- Generating Plot ---")
|
print("\n--- Generating Plot ---")
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
plt.plot(
|
||||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
BATCH_SIZES,
|
||||||
|
results_torch,
|
||||||
|
marker="o",
|
||||||
|
linestyle="-",
|
||||||
|
label=f"PyTorch ({device})",
|
||||||
|
)
|
||||||
|
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||||
|
|
||||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||||
plt.xlabel("Batch Size")
|
plt.xlabel("Batch Size")
|
||||||
plt.ylabel("Average Time per Batch (ms)")
|
plt.ylabel("Average Time per Batch (ms)")
|
||||||
plt.xticks(BATCH_SIZES)
|
plt.xticks(BATCH_SIZES)
|
||||||
@@ -124,5 +136,6 @@ def main():
|
|||||||
plt.savefig(output_filename)
|
plt.savefig(output_filename)
|
||||||
print(f"Plot saved to {output_filename}")
|
print(f"Plot saved to {output_filename}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
148  benchmarks/benchmark_no_recompute.py  Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
|
def _meta_exists(index_path: str) -> bool:
|
||||||
|
p = Path(index_path)
|
||||||
|
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||||
|
# if _meta_exists(index_path):
|
||||||
|
# return
|
||||||
|
kwargs = {}
|
||||||
|
if backend_name == "hnsw":
|
||||||
|
kwargs["is_compact"] = is_recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||||
|
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=is_recompute,
|
||||||
|
num_threads=4,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
for i in range(num_docs):
|
||||||
|
builder.add_text(
|
||||||
|
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||||
|
)
|
||||||
|
builder.build_index(index_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _bench_group(
|
||||||
|
index_path: str,
|
||||||
|
recompute: bool,
|
||||||
|
query: str,
|
||||||
|
repeats: int,
|
||||||
|
complexity: int = 32,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> float:
|
||||||
|
# Independent searcher per group; fixed port when recompute
|
||||||
|
searcher = LeannSearcher(index_path=index_path)
|
||||||
|
|
||||||
|
# Warm-up once
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _once() -> float:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=top_k,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute,
|
||||||
|
)
|
||||||
|
return time.time() - t0
|
||||||
|
|
||||||
|
if repeats <= 1:
|
||||||
|
t = _once()
|
||||||
|
else:
|
||||||
|
vals = [_once() for _ in range(repeats)]
|
||||||
|
vals.sort()
|
||||||
|
t = vals[len(vals) // 2]
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--num-docs", type=int, default=5000)
|
||||||
|
parser.add_argument("--repeats", type=int, default=3)
|
||||||
|
parser.add_argument("--complexity", type=int, default=32)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# ---------- Build HNSW variants ----------
|
||||||
|
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||||
|
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||||
|
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||||
|
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Build DiskANN variants ----------
|
||||||
|
diskann_r = str(base / "diskann_r.leann")
|
||||||
|
diskann_nr = str(base / "diskann_nr.leann")
|
||||||
|
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||||
|
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||||
|
|
||||||
|
# ---------- Helpers ----------
|
||||||
|
def _size_for(prefix: str) -> int:
|
||||||
|
p = Path(prefix)
|
||||||
|
base_dir = p.parent
|
||||||
|
stem = p.stem
|
||||||
|
total = 0
|
||||||
|
for f in base_dir.iterdir():
|
||||||
|
if f.is_file() and f.name.startswith(stem):
|
||||||
|
total += f.stat().st_size
|
||||||
|
return total
|
||||||
|
|
||||||
|
# ---------- HNSW benchmark ----------
|
||||||
|
t_hnsw_r = _bench_group(
|
||||||
|
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_hnsw_nr = _bench_group(
|
||||||
|
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
size_hnsw_r = _size_for(hnsw_r)
|
||||||
|
size_hnsw_nr = _size_for(hnsw_nr)
|
||||||
|
|
||||||
|
print("Benchmark results (HNSW):")
|
||||||
|
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(
|
||||||
|
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||||
|
)
|
||||||
|
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||||
|
|
||||||
|
# ---------- DiskANN benchmark ----------
|
||||||
|
t_diskann_r = _bench_group(
|
||||||
|
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||||
|
)
|
||||||
|
t_diskann_nr = _bench_group(
|
||||||
|
diskann_nr,
|
||||||
|
False,
|
||||||
|
"DiskANN NR test doc 123",
|
||||||
|
repeats=args.repeats,
|
||||||
|
complexity=args.complexity,
|
||||||
|
)
|
||||||
|
size_diskann_r = _size_for(diskann_r)
|
||||||
|
size_diskann_nr = _size_for(diskann_nr)
|
||||||
|
|
||||||
|
print("\nBenchmark results (DiskANN):")
|
||||||
|
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||||
|
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||||
|
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
23  benchmarks/bm25_diskann_baselines/README.md  Normal file
@@ -0,0 +1,23 @@
# BM25 vs DiskANN Baselines

```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```

- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Results are machine-specific; measured locally with the current repo.

DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)

BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)

Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN and `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
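For reference, the avg/percentile/QPS figures above can be recomputed from a list of per-query latencies. A minimal standalone sketch (the example latencies are made up); note that `run_diskann.py` reports QPS as 1/avg, while `run_bm25.py` divides the query count by total wall time:

```python
from statistics import mean


def percentile(values: list[float], p: float) -> float:
    """Linear-interpolation percentile (same convention as run_bm25.py)."""
    if not values:
        return 0.0
    s = sorted(values)
    k = (len(s) - 1) * (p / 100.0)
    f = int(k)
    c = min(f + 1, len(s) - 1)
    return s[f] if f == c else s[f] + (s[c] - s[f]) * (k - f)


latencies = [0.0107, 0.0111, 0.0150, 0.0109]  # per-query times in seconds (made up)
avg = mean(latencies)
print(
    f"avg {avg:.6f} s/query, QPS {1.0 / avg:.2f}, "
    f"p50 {percentile(latencies, 50):.6f} s, p95 {percentile(latencies, 95):.6f} s"
)
```
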
183  benchmarks/bm25_diskann_baselines/run_bm25.py  Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "pyserini"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
# sudo pacman -S jdk21-openjdk
|
||||||
|
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||||
|
# sudo archlinux-java status
|
||||||
|
# sudo archlinux-java set java-21-openjdk
|
||||||
|
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||||
|
# fish_add_path --global $JAVA_HOME/bin
|
||||||
|
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||||
|
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||||
|
_, ext = os.path.splitext(path)
|
||||||
|
if ext.lower() in {".jsonl", ".json"}:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Not strict JSONL? treat the whole line as the query
|
||||||
|
queries.append(line)
|
||||||
|
continue
|
||||||
|
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||||
|
if q:
|
||||||
|
queries.append(str(q))
|
||||||
|
else:
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
s = line.strip()
|
||||||
|
if s:
|
||||||
|
queries.append(s)
|
||||||
|
|
||||||
|
if limit is not None and limit > 0:
|
||||||
|
queries = queries[:limit]
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def percentile(values: list[float], p: float) -> float:
|
||||||
|
if not values:
|
||||||
|
return 0.0
|
||||||
|
s = sorted(values)
|
||||||
|
k = (len(s) - 1) * (p / 100.0)
|
||||||
|
f = int(k)
|
||||||
|
c = min(f + 1, len(s) - 1)
|
||||||
|
if f == c:
|
||||||
|
return s[f]
|
||||||
|
return s[f] + (s[c] - s[f]) * (k - f)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--bm25-index",
|
||||||
|
default="benchmarks/data/indices/bm25_index",
|
||||||
|
help="Path to Pyserini Lucene index directory",
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--queries",
|
||||||
|
default="benchmarks/data/queries/nq_open.jsonl",
|
||||||
|
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||||
|
)
|
||||||
|
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||||
|
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||||
|
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||||
|
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||||
|
ap.add_argument(
|
||||||
|
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||||
|
)
|
||||||
|
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyserini.search.lucene import LuceneSearcher
|
||||||
|
except Exception:
|
||||||
|
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not os.path.isdir(args.bm25_index):
|
||||||
|
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
queries = load_queries(args.queries, args.limit)
|
||||||
|
if not queries:
|
||||||
|
print("No queries loaded.", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||||
|
print(f"Opening BM25 index: {args.bm25_index}")
|
||||||
|
searcher = LuceneSearcher(args.bm25_index)
|
||||||
|
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||||
|
try:
|
||||||
|
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
total_searches = 0
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for i in range(min(args.warmup, len(queries))):
|
||||||
|
_ = searcher.search(queries[i], k=args.k)
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
for i, q in enumerate(queries):
|
||||||
|
t1 = time.time()
|
||||||
|
hits = searcher.search(q, k=args.k)
|
||||||
|
t2 = time.time()
|
||||||
|
latencies.append(t2 - t1)
|
||||||
|
total_searches += 1
|
||||||
|
|
||||||
|
if args.fetch_docs:
|
||||||
|
# Optional doc fetch to include I/O time
|
||||||
|
for h in hits:
|
||||||
|
try:
|
||||||
|
_ = searcher.doc(h.docid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if (i + 1) % 50 == 0:
|
||||||
|
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||||
|
|
||||||
|
t1 = time.time()
|
||||||
|
total_time = t1 - t0
|
||||||
|
|
||||||
|
if latencies:
|
||||||
|
avg = mean(latencies)
|
||||||
|
p50 = percentile(latencies, 50)
|
||||||
|
p90 = percentile(latencies, 90)
|
||||||
|
p95 = percentile(latencies, 95)
|
||||||
|
p99 = percentile(latencies, 99)
|
||||||
|
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||||
|
else:
|
||||||
|
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||||
|
|
||||||
|
print("BM25 Latency Report")
|
||||||
|
print(f" queries: {total_searches}")
|
||||||
|
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||||
|
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||||
|
|
||||||
|
if args.report:
|
||||||
|
payload = {
|
||||||
|
"queries": total_searches,
|
||||||
|
"k": args.k,
|
||||||
|
"k1": args.k1,
|
||||||
|
"b": args.b,
|
||||||
|
"avg_s": avg,
|
||||||
|
"p50_s": p50,
|
||||||
|
"p90_s": p90,
|
||||||
|
"p95_s": p95,
|
||||||
|
"p99_s": p99,
|
||||||
|
"total_time_s": total_time,
|
||||||
|
"qps": qps,
|
||||||
|
"index_dir": os.path.abspath(args.bm25_index),
|
||||||
|
"fetch_docs": bool(args.fetch_docs),
|
||||||
|
}
|
||||||
|
with open(args.report, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(payload, f, indent=2)
|
||||||
|
print(f"Saved report to {args.report}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "leann-backend-diskann"
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||||
|
out: list[str] = []
|
||||||
|
with open(path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
obj = json.loads(line)
|
||||||
|
out.append(obj["query"])
|
||||||
|
if limit and len(out) >= limit:
|
||||||
|
break
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
ap = argparse.ArgumentParser(
|
||||||
|
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||||
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"--index-dir",
|
||||||
|
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||||
|
help="Directory containing DiskANN files",
|
||||||
|
)
|
||||||
|
ap.add_argument("--index-prefix", default="ann")
|
||||||
|
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||||
|
ap.add_argument("--num-queries", type=int, default=200)
|
||||||
|
ap.add_argument("--top-k", type=int, default=10)
|
||||||
|
ap.add_argument("--complexity", type=int, default=62)
|
||||||
|
ap.add_argument("--threads", type=int, default=1)
|
||||||
|
ap.add_argument("--beam-width", type=int, default=1)
|
||||||
|
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||||
|
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
index_dir = Path(args.index_dir).resolve()
|
||||||
|
if not index_dir.is_dir():
|
||||||
|
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||||
|
|
||||||
|
qpath = Path(args.queries_file).resolve()
|
||||||
|
if not qpath.exists():
|
||||||
|
raise SystemExit(f"Queries file not found: {qpath}")
|
||||||
|
|
||||||
|
queries = load_queries(qpath, args.num_queries)
|
||||||
|
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||||
|
|
||||||
|
# Compute embeddings once (exclude from timing)
|
||||||
|
from leann.api import compute_embeddings as _compute
|
||||||
|
|
||||||
|
embs = _compute(
|
||||||
|
queries,
|
||||||
|
model_name="facebook/contriever-msmarco",
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
if embs.ndim != 2:
|
||||||
|
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||||
|
|
||||||
|
# Build searcher
|
||||||
|
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||||
|
|
||||||
|
index_prefix_path = str(index_dir / args.index_prefix)
|
||||||
|
searcher = _DiskannSearcher(
|
||||||
|
index_prefix_path,
|
||||||
|
num_threads=int(args.threads),
|
||||||
|
cache_mechanism=int(args.cache_mechanism),
|
||||||
|
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup (not timed)
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[0:1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Timed loop
|
||||||
|
times: list[float] = []
|
||||||
|
for i in range(embs.shape[0]):
|
||||||
|
t0 = time.time()
|
||||||
|
_ = searcher.search(
|
||||||
|
embs[i : i + 1],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.complexity,
|
||||||
|
beam_width=args.beam_width,
|
||||||
|
prune_ratio=0.0,
|
||||||
|
recompute_embeddings=False,
|
||||||
|
batch_recompute=False,
|
||||||
|
dedup_node_dis=False,
|
||||||
|
)
|
||||||
|
times.append(time.time() - t0)
|
||||||
|
|
||||||
|
times_sorted = sorted(times)
|
||||||
|
avg = float(sum(times) / len(times))
|
||||||
|
p50 = times_sorted[len(times) // 2]
|
||||||
|
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||||
|
|
||||||
|
print("\nDiskANN (NQ, search-only) Report")
|
||||||
|
print(f" queries: {len(times)}")
|
||||||
|
print(
|
||||||
|
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||||
|
)
|
||||||
|
print(f" avg per query: {avg:.6f} s")
|
||||||
|
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||||
|
print(f" QPS: {1.0 / avg:.2f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,14 +3,15 @@
|
|||||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import psutil
|
|
||||||
import gc
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import psutil
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
[sys.executable, "examples/faiss_only.py"],
|
[sys.executable, "benchmarks/faiss_only.py"],
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300,
|
timeout=300,
|
||||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
|||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if "Peak Memory:" in line:
|
if "Peak Memory:" in line:
|
||||||
peak_memory = float(
|
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"peak_memory": peak_memory}
|
return {"peak_memory": peak_memory}
|
||||||
|
|
||||||
@@ -111,13 +110,12 @@ def test_leann_hnsw():
|
|||||||
|
|
||||||
tracker.checkpoint("After imports")
|
tracker.checkpoint("After imports")
|
||||||
|
|
||||||
|
from leann.api import LeannBuilder
|
||||||
from llama_index.core import SimpleDirectoryReader
|
from llama_index.core import SimpleDirectoryReader
|
||||||
from leann.api import LeannBuilder, LeannSearcher
|
|
||||||
|
|
||||||
|
|
||||||
# Load and parse documents
|
# Load and parse documents
|
||||||
documents = SimpleDirectoryReader(
|
documents = SimpleDirectoryReader(
|
||||||
"examples/data",
|
"data",
|
||||||
recursive=True,
|
recursive=True,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
required_exts=[".pdf", ".txt", ".md"],
|
required_exts=[".pdf", ".txt", ".md"],
|
||||||
@@ -135,6 +133,7 @@ def test_leann_hnsw():
|
|||||||
nodes = node_parser.get_nodes_from_documents([doc])
|
nodes = node_parser.get_nodes_from_documents([doc])
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
all_texts.append(node.get_content())
|
all_texts.append(node.get_content())
|
||||||
|
print(f"Total number of chunks: {len(all_texts)}")
|
||||||
|
|
||||||
tracker.checkpoint("After text chunking")
|
tracker.checkpoint("After text chunking")
|
||||||
|
|
||||||
@@ -201,11 +200,9 @@ def test_leann_hnsw():
|
|||||||
searcher = LeannSearcher(index_path)
|
searcher = LeannSearcher(index_path)
|
||||||
tracker.checkpoint("After searcher loading")
|
tracker.checkpoint("After searcher loading")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print("Running search queries...")
|
print("Running search queries...")
|
||||||
queries = [
|
queries = [
|
||||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||||
"What is LEANN and how does it work?",
|
"What is LEANN and how does it work?",
|
||||||
"华为诺亚方舟实验室的主要研究内容",
|
"华为诺亚方舟实验室的主要研究内容",
|
||||||
]
|
]
|
||||||
@@ -303,21 +300,15 @@ def main():
|
|||||||
|
|
||||||
print("\nLEANN vs Faiss Performance:")
|
print("\nLEANN vs Faiss Performance:")
|
||||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||||
print(
|
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Storage comparison
|
# Storage comparison
|
||||||
if leann_storage_size > faiss_storage_size:
|
if leann_storage_size > faiss_storage_size:
|
||||||
storage_ratio = leann_storage_size / faiss_storage_size
|
storage_ratio = leann_storage_size / faiss_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
|
||||||
)
|
|
||||||
elif faiss_storage_size > leann_storage_size:
|
elif faiss_storage_size > leann_storage_size:
|
||||||
storage_ratio = faiss_storage_size / leann_storage_size
|
storage_ratio = faiss_storage_size / leann_storage_size
|
||||||
print(
|
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print(" Storage Size: similar")
|
print(" Storage Size: similar")
|
||||||
else:
|
else:
|
||||||
0  data/README.md → benchmarks/data/README.md  Normal file → Executable file
286  benchmarks/diskann_vs_hnsw_speed_comparison.py  Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
DiskANN vs HNSW Search Performance Comparison
|
||||||
|
|
||||||
|
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||||
|
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||||
|
- HNSW: With recompute enabled (is_recompute=True)
|
||||||
|
- Tests performance across different dataset sizes
|
||||||
|
- Measures search latency, recall, and index size
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import multiprocessing as mp
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||||
|
try:
|
||||||
|
mp.set_start_method("fork", force=True)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_texts(n_docs: int) -> list[str]:
|
||||||
|
"""Create synthetic test documents for benchmarking."""
|
||||||
|
np.random.seed(42)
|
||||||
|
topics = [
|
||||||
|
"machine learning and artificial intelligence",
|
||||||
|
"natural language processing and text analysis",
|
||||||
|
"computer vision and image recognition",
|
||||||
|
"data science and statistical analysis",
|
||||||
|
"deep learning and neural networks",
|
||||||
|
"information retrieval and search engines",
|
||||||
|
"database systems and data management",
|
||||||
|
"software engineering and programming",
|
||||||
|
"cybersecurity and network protection",
|
||||||
|
"cloud computing and distributed systems",
|
||||||
|
]
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
for i in range(n_docs):
|
||||||
|
topic = topics[i % len(topics)]
|
||||||
|
variation = np.random.randint(1, 100)
|
||||||
|
text = (
|
||||||
|
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||||
|
f"Additional information about {topic} with details and examples. "
|
||||||
|
f"Technical discussion of {topic} including implementation aspects."
|
||||||
|
)
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_backend(
|
||||||
|
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark a specific backend with the given configuration."""
|
||||||
|
from leann.api import LeannBuilder, LeannSearcher
|
||||||
|
|
||||||
|
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend_name,
|
||||||
|
embedding_model="facebook/contriever",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
**backend_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
builder.add_text(text)
|
||||||
|
|
||||||
|
builder.build_index(index_path)
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Measure index size
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||||
|
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||||
|
size_mb = total_size / (1024 * 1024)
|
||||||
|
|
||||||
|
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||||
|
|
||||||
|
# Search benchmark
|
||||||
|
print("🔍 Running search benchmark...")
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
results = searcher.search(query, top_k=5)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
all_results.append(results)
|
||||||
|
|
||||||
|
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||||
|
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||||
|
|
||||||
|
# Check for valid scores (detect -inf issues)
|
||||||
|
all_scores = [
|
||||||
|
result.score
|
||||||
|
for results in all_results
|
||||||
|
for result in results
|
||||||
|
if result.score is not None
|
||||||
|
]
|
||||||
|
valid_scores = [
|
||||||
|
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||||
|
]
|
||||||
|
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||||
|
|
||||||
|
# Clean up (ensure embedding server shutdown and object GC)
|
||||||
|
try:
|
||||||
|
if hasattr(searcher, "cleanup"):
|
||||||
|
searcher.cleanup()
|
||||||
|
del searcher
|
||||||
|
del builder
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"build_time": build_time,
|
||||||
|
"avg_search_time_ms": avg_search_time,
|
||||||
|
"index_size_mb": size_mb,
|
||||||
|
"score_validity_rate": score_validity_rate,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||||
|
"""Run performance comparison between DiskANN and HNSW."""
|
||||||
|
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||||
|
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
texts = create_test_texts(n_docs)
|
||||||
|
test_queries = [
|
||||||
|
"machine learning algorithms",
|
||||||
|
"natural language processing",
|
||||||
|
"computer vision techniques",
|
||||||
|
"data analysis methods",
|
||||||
|
"neural network architectures",
|
||||||
|
"database query optimization",
|
||||||
|
"software development practices",
|
||||||
|
"security vulnerabilities",
|
||||||
|
"cloud infrastructure",
|
||||||
|
"distributed computing",
|
||||||
|
][:n_queries]
|
||||||
|
|
||||||
|
# HNSW benchmark
|
||||||
|
hnsw_results = benchmark_backend(
|
||||||
|
backend_name="hnsw",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable recompute for fair comparison
|
||||||
|
"M": 16,
|
||||||
|
"efConstruction": 200,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiskANN benchmark
|
||||||
|
diskann_results = benchmark_backend(
|
||||||
|
backend_name="diskann",
|
||||||
|
texts=texts,
|
||||||
|
test_queries=test_queries,
|
||||||
|
backend_kwargs={
|
||||||
|
"is_recompute": True, # Enable graph partitioning
|
||||||
|
"num_neighbors": 32,
|
||||||
|
"search_list_size": 50,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
print("\n📈 Performance Comparison Results")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||||
|
print(f"{'-' * 60}")
|
||||||
|
|
||||||
|
# Build time comparison
|
||||||
|
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||||
|
print(
|
||||||
|
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search time comparison
|
||||||
|
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||||
|
print(
|
||||||
|
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Index size comparison
|
||||||
|
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||||
|
print(
|
||||||
|
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score validity
|
||||||
|
print(
|
||||||
|
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
print("\n🎯 Summary:")
|
||||||
|
if search_speedup > 1:
|
||||||
|
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||||
|
else:
|
||||||
|
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||||
|
|
||||||
|
if size_ratio > 1:
|
||||||
|
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||||
|
else:
|
||||||
|
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle help request
|
||||||
|
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||||
|
print()
|
||||||
|
print("Arguments:")
|
||||||
|
print(" n_docs Number of documents to index (default: 500)")
|
||||||
|
print(" n_queries Number of test queries to run (default: 10)")
|
||||||
|
print()
|
||||||
|
print("Examples:")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||||
|
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||||
|
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||||
|
|
||||||
|
print("DiskANN vs HNSW Performance Comparison")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||||
|
print()
|
||||||
|
|
||||||
|
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Benchmark interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Benchmark failed: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
finally:
|
||||||
|
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||||
|
try:
|
||||||
|
gc.collect()
|
||||||
|
print("\n🧹 Cleanup completed")
|
||||||
|
# Flush stdio to ensure message is visible before hard-exit
|
||||||
|
try:
|
||||||
|
import sys as _sys
|
||||||
|
|
||||||
|
_sys.stdout.flush()
|
||||||
|
_sys.stderr.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||||
|
import os as _os
|
||||||
|
|
||||||
|
_os._exit(0)
|
||||||
141  benchmarks/enron_emails/README.md  Normal file
@@ -0,0 +1,141 @@
# Enron Emails Benchmark

A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.

- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
- Metrics: Recall@3 vs FAISS Flat baseline + generation evaluation with Qwen3-8B

## Layout

benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build the LEANN index, build the FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
- data/: Generated passages, queries, embedding-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)

## Quickstart

1) Prepare the data and index

cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data

Notes:
- If `--emails-csv` is omitted, the script attempts to download the Kaggle dataset `wcukierski/enron-email-dataset` via the Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`). Alternatively, pass a local path to `--emails-csv`.
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it also creates evaluation queries from the HuggingFace dataset `corbt/enron_emails_sample_questions`.

2) Run recall evaluation (Stage 2)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2

3) Complexity sweep (Stage 3)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200

Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (it assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8 (see the sketch after this README).

4) Index comparison (Stage 4)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json

5) Generation evaluation (Stage 5)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B

6) Combined index + generation evaluation (Stages 4+5, recommended)

python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf

Notes:
- Minimal CLI: you can run from the repo root with only `--index`; defaults match the financebench/laion patterns:
  - `--stage` defaults to `all` (runs 2, 3, 4, 5)
  - `--baseline-dir` defaults to `baseline`
  - `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
  - `--llm-backend` defaults to `hf` (HuggingFace); `vllm` is also supported
  - `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If the compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.

Optional flags:
- --queries data/evaluation_queries.jsonl (custom queries file)
- --baseline-dir baseline (where the FAISS baseline lives)
- --complexity 88 (LEANN complexity parameter; optimal for 90% recall)
- --llm-backend hf|vllm (LLM backend for generation)
- --model-name Qwen/Qwen3-8B (LLM model for generation)
- --max-queries 1000 (limit the number of queries for evaluation)

## Files Produced
- data/enron_passages_preview.jsonl: Small preview of the passages used (for inspection)
- data/enron_index_hnsw.leann.*: LEANN index files
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)

## Notes
- Evaluates both retrieval Recall@3 and generation timing with the Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for the source identifier. Message-ID headers are parsed as canonical message IDs when present.
- Qwen3-8B is a thinking model and needs special handling: chat templates plus <think></think> tag processing.

## Stages Summary

- Stage 2 (Recall@3):
  - Compares LEANN vs the FAISS Flat baseline on Recall@3.
  - The compact index runs with `recompute_embeddings=True`.

- Stage 3 (Binary Search for Complexity):
  - Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving the target Recall@3 (default 90%).

- Stage 4 (Index Comparison):
  - Reports .index-only sizes for compact vs non-compact.
  - Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
  - Stores retrieval results for the Stage 5 generation evaluation.
  - Fails fast if compact recompute cannot run.
  - If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
    - first from the current run (when running `--stage all`), otherwise
    - from `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
    - If neither exists, Stage 4 errors and asks you to run Stage 3 or pass `--complexity`.

- Stage 5 (Generation Evaluation):
  - Uses the Qwen3-8B thinking model for RAG generation on the documents retrieved in Stage 4.
  - Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
  - Measures generation timing separately from search timing.
  - Requires Stage 4 results (no additional searching is performed).

## Example Results

These are sample results obtained on the Enron data using all-mpnet-base-v2 and Qwen3-8B.

- Stage 3 (Binary Search):
  - Minimal complexity achieving 90% Recall@3: 88
  - Sampled points:
    - C=8 → 59.9% Recall@3
    - C=72 → 89.4% Recall@3
    - C=88 → 90.2% Recall@3
    - C=96 → 90.7% Recall@3
    - C=112 → 91.1% Recall@3
    - C=136 → 91.3% Recall@3
    - C=256 → 92.0% Recall@3

- Stage 4 (Index Sizes, .index only):
  - Compact: ~2.2 MB
  - Non-compact: ~82.0 MB
  - Storage saving from the compact index: ~97.3%

- Stage 4 (Search Timing, 988 queries, complexity=88):
  - Non-compact (no recompute): ~0.0075 s avg per query
  - Compact (with recompute): ~1.981 s avg per query
  - Speed ratio (non-compact/compact): ~0.0038x

- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
  - Average generation time: ~22.302 s per query
  - Total queries processed: 988
  - LLM backend: HuggingFace transformers
  - Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)

Full JSON output is saved by the script (see `--output`), e.g. `benchmarks/enron_emails/results_enron_stage45.json`.
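The Stage 3 procedure described in the Quickstart above is compact enough to sketch on its own. In the sketch below, `recall_at_3` is a placeholder for a callable that evaluates Recall@3 at a given complexity (in evaluate_enron_emails.py this role is played by `RecallEvaluator.evaluate_recall_at_3`); the names `round_c` and `min_complexity_for_target` are illustrative, not part of the benchmark:

```python
def round_c(x: int) -> int:
    # Snap complexity to a multiple of 8, as the evaluator does.
    return max(1, (x + 7) // 8 * 8)


def min_complexity_for_target(recall_at_3, target=0.90, lo=8, hi=256, cap=1024, max_iters=10):
    """Binary search for the smallest complexity whose Recall@3 reaches `target`,
    assuming recall is non-decreasing in complexity."""
    lo, hi = round_c(lo), round_c(hi)
    r_hi = recall_at_3(hi)
    # Expand the upper bound until it reaches the target (or the cap).
    while r_hi < target and hi < cap:
        lo = hi
        hi = round_c(hi * 2)
        r_hi = recall_at_3(hi)
    if r_hi < target:
        return None  # target recall not reachable within the cap
    best = hi
    iters = 0
    while lo < hi and iters < max_iters:
        mid = round_c((lo + hi) // 2)
        if recall_at_3(mid) >= target:
            best = mid
            hi = mid
        else:
            lo = mid + 8  # step past mid, keeping multiples of 8
        iters += 1
    return best
```

Rounding `mid` up to a multiple of 8 mirrors the evaluator and keeps the sampled complexities comparable across stages; the `max_iters` cap guards against the search stalling once `lo` and `hi` are within one rounding step of each other.
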
1  benchmarks/enron_emails/data/.gitignore  vendored  Normal file
@@ -0,0 +1 @@
downloads/
614  benchmarks/enron_emails/evaluate_enron_emails.py  Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||||
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||||
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||||
|
On errors, fail fast without fallbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Compute query embedding with the same model/mode as the index
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS Flat ground truth
|
||||||
|
n = q_emb.shape[0]
|
||||||
|
k = 3
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(q_emb),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN (may require embedding server depending on index configuration)
|
||||||
|
results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {r.id for r in results}
|
||||||
|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN: {list(test_ids)}")
|
||||||
|
print(f" ∩: {list(intersection)}")
|
||||||
|
|
||||||
|
avg = total_recall / max(1, len(queries))
|
||||||
|
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||||
|
return avg
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class EnronEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
if "query" in data:
|
||||||
|
queries.append(data["query"])
|
||||||
|
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||||
|
|
||||||
|
sizes["index_only_mb"] = (
|
||||||
|
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source and embedding model
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
# Load all passages and ids
|
||||||
|
ids: list[str] = []
|
||||||
|
texts: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
texts.append(data["text"])
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
meta["embedding_model"],
|
||||||
|
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False,
|
||||||
|
is_compact=False,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
    def compare_index_performance(
        self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
    ) -> dict:
        """Compare search speed for non-compact vs compact indexes."""
        import time

        results: dict = {
            "non_compact": {"search_times": []},
            "compact": {"search_times": []},
            "avg_search_times": {},
            "speed_ratio": 0.0,
            "retrieval_results": [],  # Store retrieval results for Stage 5
        }

        print("⚡ Comparing search performance between indexes...")
        # Non-compact (no recompute)
        print(" 🔍 Testing non-compact index (no recompute)...")
        non_compact_searcher = LeannSearcher(non_compact_path)
        for q in test_queries:
            t0 = time.time()
            _ = non_compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=False
            )
            results["non_compact"]["search_times"].append(time.time() - t0)

        # Compact (with recompute). Fail fast if it cannot run.
        print(" 🔍 Testing compact index (with recompute)...")
        compact_searcher = LeannSearcher(compact_path)
        for q in test_queries:
            t0 = time.time()
            docs = compact_searcher.search(
                q, top_k=3, complexity=complexity, recompute_embeddings=True
            )
            results["compact"]["search_times"].append(time.time() - t0)

            # Store retrieval results for Stage 5
            results["retrieval_results"].append(
                {"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
            )
        compact_searcher.cleanup()

        if results["non_compact"]["search_times"]:
            results["avg_search_times"]["non_compact"] = sum(
                results["non_compact"]["search_times"]
            ) / len(results["non_compact"]["search_times"])
        if results["compact"]["search_times"]:
            results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
                results["compact"]["search_times"]
            )
        if results["avg_search_times"].get("compact", 0) > 0:
            results["speed_ratio"] = (
                results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
            )
        else:
            results["speed_ratio"] = 0.0

        non_compact_searcher.cleanup()
        return results

    def evaluate_complexity(
        self,
        recall_eval: "RecallEvaluator",
        queries: list[str],
        target: float = 0.90,
        c_min: int = 8,
        c_max: int = 256,
        max_iters: int = 10,
        recompute: bool = False,
    ) -> dict:
        """Binary search minimal complexity achieving target recall (monotonic assumption)."""

        def round_c(x: int) -> int:
            # snap to multiple of 8 like other benches typically do
            return max(1, int((x + 7) // 8) * 8)

        metrics: list[dict] = []

        lo = round_c(c_min)
        hi = round_c(c_max)

        print(
            f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
        )

        # Ensure upper bound can reach target; expand if needed (up to a cap)
        r_lo = recall_eval.evaluate_recall_at_3(
            queries, complexity=lo, recompute_embeddings=recompute
        )
        metrics.append({"complexity": lo, "recall_at_3": r_lo})
        r_hi = recall_eval.evaluate_recall_at_3(
            queries, complexity=hi, recompute_embeddings=recompute
        )
        metrics.append({"complexity": hi, "recall_at_3": r_hi})

        cap = 1024
        while r_hi < target and hi < cap:
            lo = hi
            r_lo = r_hi
            hi = round_c(hi * 2)
            r_hi = recall_eval.evaluate_recall_at_3(
                queries, complexity=hi, recompute_embeddings=recompute
            )
            metrics.append({"complexity": hi, "recall_at_3": r_hi})

        if r_hi < target:
            print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
            print("📈 Observations:")
            for m in metrics:
                print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
            return {"metrics": metrics, "best_complexity": None, "target_recall": target}

        # Binary search within [lo, hi]
        best = hi
        iters = 0
        while lo < hi and iters < max_iters:
            mid = round_c((lo + hi) // 2)
            r_mid = recall_eval.evaluate_recall_at_3(
                queries, complexity=mid, recompute_embeddings=recompute
            )
            metrics.append({"complexity": mid, "recall_at_3": r_mid})
            if r_mid >= target:
                best = mid
                hi = mid
            else:
                lo = mid + 8  # move past mid, respecting multiple-of-8 step
            iters += 1

        print("📈 Binary search results (sampled points):")
        # Print unique complexity entries ordered by complexity
        for m in sorted(
            {m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
        ):
            print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
        print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
        return {"metrics": metrics, "best_complexity": best, "target_recall": target}


def main():
    parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
    parser.add_argument("--index", required=True, help="Path to LEANN index")
    parser.add_argument(
        "--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
    )
    parser.add_argument(
        "--stage",
        choices=["2", "3", "4", "5", "all", "45"],
        default="all",
        help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
    )
    parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
    parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
    parser.add_argument(
        "--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
    )
    parser.add_argument(
        "--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
    )
    parser.add_argument("--output", help="Save results to JSON file")
    parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
    parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")

    args = parser.parse_args()

    # Resolve queries file: if default path not found, fall back to index's directory
    if not os.path.exists(args.queries):
        from pathlib import Path

        idx_dir = Path(args.index).parent
        fallback_q = idx_dir / "evaluation_queries.jsonl"
        if fallback_q.exists():
            args.queries = str(fallback_q)

    baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
    if not os.path.exists(baseline_index_path):
        print(f"❌ FAISS baseline not found at {baseline_index_path}")
        print("💡 Please run setup_enron_emails.py first to build the baseline")
        raise SystemExit(1)

    results_out: dict = {}

    if args.stage in ("2", "all"):
        print("🚀 Starting Stage 2: Recall@3 evaluation")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)

        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[:10]
        print(f"🧪 Using first {len(queries)} queries")

        complexity = args.complexity or 64
        r = evaluator.evaluate_recall_at_3(queries, complexity)
        results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 2 completed!\n")

    if args.stage in ("3", "all"):
        print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        queries = queries[: args.max_queries]
        print(f"🧪 Using first {len(queries)} queries")

        # Build non-compact index for fast binary search (recompute_embeddings=False)
        from pathlib import Path

        index_path = Path(args.index)
        non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)

        # Use non-compact evaluator for binary search with recompute=False
        evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
        sweep = enron_eval.evaluate_complexity(
            evaluator_nc, queries, target=args.target_recall, recompute=False
        )
        results_out["stage3"] = sweep
        # Persist default stage 3 results near the index for Stage 4 auto-pickup
        from pathlib import Path

        default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
        with open(default_stage3_path, "w", encoding="utf-8") as f:
            json.dump({"stage3": sweep}, f, indent=2)
        print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
        evaluator_nc.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 3 completed!\n")

    if args.stage in ("4", "all", "45"):
        print("🚀 Starting Stage 4: Index size + performance comparison")
        evaluator = RecallEvaluator(args.index, args.baseline_dir)
        enron_eval = EnronEvaluator(args.index)
        queries = enron_eval.load_queries(args.queries)
        test_q = queries[: min(args.max_queries, len(queries))]

        current_sizes = enron_eval.analyze_index_sizes()
        # Build non-compact index for comparison (no fallback)
        from pathlib import Path

        index_path = Path(args.index)
        non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
        non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
        nc_eval = EnronEvaluator(non_compact_path)

        if (
            current_sizes.get("index_only_mb", 0) > 0
            and non_compact_sizes.get("index_only_mb", 0) > 0
        ):
            storage_saving_percent = max(
                0.0,
                100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
            )
        else:
            storage_saving_percent = 0.0

        if args.complexity is None:
            # Prefer in-session Stage 3 result
            if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
                complexity = results_out["stage3"]["best_complexity"]
                print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
            else:
                # Try to load last saved Stage 3 result near index
                default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
                if default_stage3_path.exists():
                    with open(default_stage3_path, encoding="utf-8") as f:
                        prev = json.load(f)
                    complexity = prev.get("stage3", {}).get("best_complexity")
                    if complexity is None:
                        raise SystemExit(
                            "❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
                        )
                    print(f"📥 Using best complexity from saved Stage 3: {complexity}")
                else:
                    raise SystemExit(
                        "❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
                    )
        else:
            complexity = args.complexity

        comp = enron_eval.compare_index_performance(
            non_compact_path, args.index, test_q, complexity=complexity
        )
        results_out["stage4"] = {
            "current_index": current_sizes,
            "non_compact_index": non_compact_sizes,
            "storage_saving_percent": storage_saving_percent,
            "performance_comparison": comp,
        }
        nc_eval.cleanup()
        evaluator.cleanup()
        enron_eval.cleanup()
        print("✅ Stage 4 completed!\n")

    if args.stage in ("5", "all"):
        print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")

        # Check if Stage 4 results exist
        if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
            print("❌ Stage 5 requires Stage 4 retrieval results")
            print("💡 Run Stage 4 first or use --stage all")
            raise SystemExit(1)

        retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
        if not retrieval_results:
            print("❌ No retrieval results found from Stage 4")
            raise SystemExit(1)

        print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")

        # Load LLM
        try:
            if args.llm_backend == "hf":
                tokenizer, model = load_hf_model(args.model_name)

                def llm_func(prompt):
                    return generate_hf(tokenizer, model, prompt)
            else:  # vllm
                llm, sampling_params = load_vllm_model(args.model_name)

                def llm_func(prompt):
                    return generate_vllm(llm, sampling_params, prompt)

            # Run generation using stored retrieval results
            import time

            from llm_utils import create_prompt

            generation_times = []
            responses = []

            print("🤖 Running generation on pre-retrieved results...")
            for i, item in enumerate(retrieval_results):
                query = item["query"]
                retrieved_docs = item["retrieved_docs"]

                # Prepare context from retrieved docs
                context = "\n\n".join([doc["text"] for doc in retrieved_docs])
                prompt = create_prompt(context, query, "emails")

                # Time generation only
                gen_start = time.time()
                response = llm_func(prompt)
                gen_time = time.time() - gen_start

                generation_times.append(gen_time)
                responses.append(response)

                if i < 3:
                    print(f" Q{i + 1}: Gen={gen_time:.3f}s")

            avg_gen_time = sum(generation_times) / len(generation_times)

            print("\n📊 Generation Results:")
            print(f" Total Queries: {len(retrieval_results)}")
            print(f" Avg Generation Time: {avg_gen_time:.3f}s")
            print(" (Search time from Stage 4)")

            results_out["stage5"] = {
                "total_queries": len(retrieval_results),
                "avg_generation_time": avg_gen_time,
                "generation_times": generation_times,
                "responses": responses,
            }

            # Show sample results
            print("\n📝 Sample Results:")
            for i in range(min(3, len(retrieval_results))):
                query = retrieval_results[i]["query"]
                response = responses[i]
                print(f" Q{i + 1}: {query[:60]}...")
                print(f" A{i + 1}: {response[:100]}...")
                print()

        except Exception as e:
            print(f"❌ Generation evaluation failed: {e}")
            print("💡 Make sure transformers/vllm is installed and model is available")

        print("✅ Stage 5 completed!\n")

    if args.output and results_out:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results_out, f, indent=2)
        print(f"📝 Saved results to {args.output}")


if __name__ == "__main__":
    main()
359 benchmarks/enron_emails/setup_enron_emails.py (new file)
@@ -0,0 +1,359 @@
"""
|
||||||
|
Enron Emails Benchmark Setup Script
|
||||||
|
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from email import message_from_string
|
||||||
|
from email.policy import default
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class EnronSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||||
|
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||||
|
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
self.downloads_dir = self.data_dir / "downloads"
|
||||||
|
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Dataset acquisition
|
||||||
|
# ----------------------------
|
||||||
|
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||||
|
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||||
|
if emails_csv:
|
||||||
|
p = Path(emails_csv)
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
|
api.dataset_download_files(
|
||||||
|
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||||
|
)
|
||||||
|
candidate = self.downloads_dir / "emails.csv"
|
||||||
|
if candidate.exists():
|
||||||
|
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||||
|
return str(candidate)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Data preparation
|
||||||
|
# ----------------------------
|
||||||
|
@staticmethod
|
||||||
|
def _extract_message_id(raw_email: str) -> str:
|
||||||
|
msg = message_from_string(raw_email, policy=default)
|
||||||
|
val = msg.get("Message-ID", "")
|
||||||
|
if val.startswith("<") and val.endswith(">"):
|
||||||
|
val = val[1:-1]
|
||||||
|
return val or ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||||
|
parts = raw_email.split("\n\n", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0].strip(), parts[1].strip()
|
||||||
|
# Heuristic fallback
|
||||||
|
first_lines = raw_email.splitlines()
|
||||||
|
if first_lines and ":" in first_lines[0]:
|
||||||
|
return raw_email.strip(), ""
|
||||||
|
return "", raw_email.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||||
|
text = (text or "").strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
if chunk_words <= 0:
|
||||||
|
return [text]
|
||||||
|
words = text.split()
|
||||||
|
if not words:
|
||||||
|
return []
|
||||||
|
limit = len(words)
|
||||||
|
if not keep_last:
|
||||||
|
limit = (len(words) // chunk_words) * chunk_words
|
||||||
|
if limit == 0:
|
||||||
|
return []
|
||||||
|
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||||
|
return [c for c in (s.strip() for s in chunks) if c]
|
||||||
|
|
||||||
|
def _iter_passages_from_csv(
|
||||||
|
self,
|
||||||
|
emails_csv: Path,
|
||||||
|
chunk_words: int = 256,
|
||||||
|
keep_last_header: bool = True,
|
||||||
|
keep_last_body: bool = True,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> Iterable[dict]:
|
||||||
|
with open(emails_csv, encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
count = 0
|
||||||
|
for i, row in enumerate(reader):
|
||||||
|
if max_emails is not None and count >= max_emails:
|
||||||
|
break
|
||||||
|
|
||||||
|
raw_message = row.get("message", "")
|
||||||
|
email_file_id = row.get("file", "")
|
||||||
|
|
||||||
|
if not raw_message.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
message_id = self._extract_message_id(raw_message)
|
||||||
|
if not message_id:
|
||||||
|
# Fallback ID based on CSV position and file path
|
||||||
|
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||||
|
message_id = f"enron_{i}_{safe_file}"
|
||||||
|
|
||||||
|
header, body = self._split_header_body(raw_message)
|
||||||
|
|
||||||
|
# Header chunks
|
||||||
|
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": True,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Body chunks
|
||||||
|
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||||
|
yield {
|
||||||
|
"text": chunk,
|
||||||
|
"metadata": {
|
||||||
|
"message_id": message_id,
|
||||||
|
"is_header": False,
|
||||||
|
"email_file_id": email_file_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Build LEANN index and FAISS baseline
|
||||||
|
# ----------------------------
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
emails_csv: Optional[str],
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
chunk_words: int = 256,
|
||||||
|
max_emails: int | None = None,
|
||||||
|
) -> str:
|
||||||
|
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||||
|
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream passages and add to builder
|
||||||
|
preview_written = 0
|
||||||
|
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||||
|
for p in self._iter_passages_from_csv(
|
||||||
|
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||||
|
):
|
||||||
|
builder.add_text(p["text"], metadata=p["metadata"])
|
||||||
|
if preview_written < 200:
|
||||||
|
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||||
|
preview_written += 1
|
||||||
|
|
||||||
|
print(f"🔨 Building index at {self.index_path}...")
|
||||||
|
builder.build_index(str(self.index_path))
|
||||||
|
print("✅ LEANN index built!")
|
||||||
|
return str(self.index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||||
|
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||||
|
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read meta for passage source and embedding model
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
# Load passages from builder output so IDs match LEANN
|
||||||
|
passages: list[str] = []
|
||||||
|
passage_ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"]) # builder-assigned ID
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||||
|
print(f"🤖 Embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS IndexFlatIP
|
||||||
|
dim = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dim)
|
||||||
|
emb_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||||
|
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as pf:
|
||||||
|
pickle.dump(passage_ids, pf)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved: {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Queries (optional): prepare evaluation queries file
|
||||||
|
# ----------------------------
|
||||||
|
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||||
|
print(
|
||||||
|
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Failed to load dataset: {e}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
kept = 0
|
||||||
|
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||||
|
for i, item in enumerate(ds):
|
||||||
|
how_realistic = float(item.get("how_realistic", 0.0))
|
||||||
|
if how_realistic < min_realism:
|
||||||
|
continue
|
||||||
|
qid = str(item.get("id", f"enron_q_{i}"))
|
||||||
|
query = item.get("question", "")
|
||||||
|
if not query:
|
||||||
|
continue
|
||||||
|
record = {
|
||||||
|
"id": qid,
|
||||||
|
"query": query,
|
||||||
|
# For reference only, not used in recall metric below
|
||||||
|
"gt_message_ids": item.get("message_ids", []),
|
||||||
|
}
|
||||||
|
out.write(json.dumps(record) + "\n")
|
||||||
|
kept += 1
|
||||||
|
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||||
|
return self.queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--emails-csv",
|
||||||
|
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model for LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||||
|
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||||
|
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
setup = EnronSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Build index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
emails_csv=args.emails_csv,
|
||||||
|
backend=args.backend,
|
||||||
|
embedding_model=args.embedding_model,
|
||||||
|
chunk_words=args.chunk_words,
|
||||||
|
max_emails=args.max_emails,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build FAISS baseline from the same passages & embeddings
|
||||||
|
setup.build_faiss_flat_baseline(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LEANN index build and baseline")
|
||||||
|
|
||||||
|
# Queries file (optional)
|
||||||
|
if not args.skip_queries:
|
||||||
|
setup.prepare_queries()
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping query preparation")
|
||||||
|
|
||||||
|
print("\n🎉 Enron Emails setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("Next steps:")
|
||||||
|
print(
|
||||||
|
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 """Test only Faiss HNSW"""

+import os
 import sys
 import time

 import psutil
-import gc
-import os


 def get_memory_usage():
@@ -37,20 +37,20 @@ def main():
         import faiss
     except ImportError:
         print("Faiss is not installed.")
-        print("Please install it with `uv pip install faiss-cpu`")
+        print(
+            "Please install it with `uv pip install faiss-cpu` and you can then run this script again"
+        )
         sys.exit(1)

     from llama_index.core import (
-        SimpleDirectoryReader,
-        VectorStoreIndex,
-        StorageContext,
         Settings,
-        node_parser,
-        Document,
+        SimpleDirectoryReader,
+        StorageContext,
+        VectorStoreIndex,
     )
     from llama_index.core.node_parser import SentenceSplitter
-    from llama_index.vector_stores.faiss import FaissVectorStore
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.vector_stores.faiss import FaissVectorStore

     tracker = MemoryTracker("Faiss HNSW")
     tracker.checkpoint("Initial")
@@ -65,7 +65,7 @@ def main():
     tracker.checkpoint("After Faiss index creation")

     documents = SimpleDirectoryReader(
-        "examples/data",
+        "data",
         recursive=True,
         encoding="utf-8",
         required_exts=[".pdf", ".txt", ".md"],
@@ -90,8 +90,9 @@ def main():
             vector_store=vector_store, persist_dir="./storage_faiss"
         )
         from llama_index.core import load_index_from_storage
+
         index = load_index_from_storage(storage_context=storage_context)
-        print(f"Index loaded from ./storage_faiss")
+        print("Index loaded from ./storage_faiss")
         tracker.checkpoint("After loading existing index")
         index_loaded = True
     except Exception as e:
@@ -99,6 +100,7 @@ def main():
         print("Cleaning up corrupted index and building new one...")
         # Clean up corrupted index
         import shutil
+
         if os.path.exists("./storage_faiss"):
             shutil.rmtree("./storage_faiss")

@@ -109,9 +111,7 @@ def main():
         vector_store = FaissVectorStore(faiss_index=faiss_index)
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
         index = VectorStoreIndex.from_documents(
-            documents,
-            storage_context=storage_context,
-            transformations=[node_parser]
+            documents, storage_context=storage_context, transformations=[node_parser]
         )
         tracker.checkpoint("After index building")

@@ -127,7 +127,7 @@ def main():

     query_engine = index.as_query_engine(similarity_top_k=20)
     queries = [
         "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
         "What is LEANN and how does it work?",
         "华为诺亚方舟实验室的主要研究内容",
     ]
115 benchmarks/financebench/README.md (new file)
@@ -0,0 +1,115 @@

# FinanceBench Benchmark for LEANN-RAG

FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.

## Dataset

- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)

## Structure

```
benchmarks/financebench/
├── setup_financebench.py        # Downloads PDFs and builds index
├── evaluate_financebench.py     # Intelligent evaluation script
├── data/
│   ├── financebench_merged.jsonl   # Q&A dataset
│   ├── pdfs/                       # Downloaded financial documents
│   └── index/                      # LEANN indexes
│       └── financebench_full_hnsw.leann
└── README.md
```

## Usage

### 1. Setup (Download & Build Index)

```bash
cd benchmarks/financebench
python setup_financebench.py
```

This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify the setup with a test query

### 2. Evaluation

```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann

# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```

## Evaluation Methods

### Retrieval Evaluation
Uses intelligent matching with three strategies (see the sketch below):
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with a 20% threshold
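The matching helper itself lives in `evaluate_financebench.py` and is not reproduced in this README, so the following is only a minimal sketch of what the three checks could look like. The function name `match_passage`, the number-extraction regex, and the way the 20% word-overlap threshold is applied are illustrative assumptions, not the benchmark's actual code.

```python
import re


def match_passage(passage: str, evidence: str) -> dict:
    """Hypothetical sketch of the three matching strategies (not the real implementation)."""
    p, e = passage.lower(), evidence.lower()

    # 1. Exact text overlap: the evidence text appears verbatim in the passage.
    exact = e.strip() in p

    # 2. Number matching: normalize and compare key figures such as "$1,577" or "1.2".
    number_re = r"\$?\d[\d,]*(?:\.\d+)?"
    extract = lambda s: {n.replace("$", "").replace(",", "") for n in re.findall(number_re, s)}
    number = bool(extract(e) & extract(p))

    # 3. Semantic similarity: share of evidence words also present in the passage (>= 20%).
    e_words = set(re.findall(r"[a-z0-9]+", e))
    p_words = set(re.findall(r"[a-z0-9]+", p))
    semantic = len(e_words & p_words) / max(len(e_words), 1) >= 0.20

    return {"exact": exact, "number": number, "semantic": semantic}
```

A passage can satisfy more than one strategy at once, which is why the per-strategy rates in the results below are tracked separately.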

### QA Evaluation
LLM-based answer evaluation using GPT-4o (see the sketch below):
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match

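The judge used for the accuracy numbers below is implemented in the evaluation script; as a rough illustration, an LLM-as-judge call of this kind might look like the sketch that follows. The prompt wording, the strict CORRECT/INCORRECT protocol, and the `gpt-4o` model name are assumptions for this sketch, not the exact code.

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def judge_answer(question: str, ground_truth: str, generated: str) -> bool:
    # Ask the model to compare answers, tolerating rounding and equivalent representations.
    prompt = (
        f"Question: {question}\n"
        f"Ground truth answer: {ground_truth}\n"
        f"Generated answer: {generated}\n\n"
        "Treat rounded numbers, fractions, percentages, and other equivalent "
        "representations as matches. Reply with exactly CORRECT or INCORRECT."
    )
    resp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=5,
        temperature=0,
    )
    return resp.choices[0].message.content.strip().upper() == "CORRECT"
```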
## Benchmark Results

### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)

**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s

**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)

**System Performance:**
- **Index Size**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2

*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.

### LEANN-RAG Generation Performance (Qwen3-8B)

- **Stage 4 (Index Comparison):**
  - Compact Index: 5.0 MB
  - Non-compact Index: 172.2 MB
  - **Storage Saving**: 97.1%
  - **Search Performance**:
    - Non-compact (no recompute): 0.009s avg per query
    - Compact (with recompute): 2.203s avg per query
    - Speed ratio: 0.004x

**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20

## Options

```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann

# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```
923 benchmarks/financebench/evaluate_financebench.py (new executable file)
@@ -0,0 +1,923 @@
"""
|
||||||
|
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
from leann import LeannChat, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(queries)
|
||||||
|
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Get ground truth: search with FAISS flat
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
query_embedding = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchEvaluator:
|
||||||
|
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||||
|
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||||
|
|
||||||
|
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||||
|
"""Load FinanceBench dataset"""
|
||||||
|
data = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||||
|
return data
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes with and without embeddings"""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes = {}
|
||||||
|
total_with_embeddings = 0
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
for file_path, name in [
|
||||||
|
(index_file, "index"),
|
||||||
|
(meta_file, "metadata"),
|
||||||
|
(passages_file, "passages_text"),
|
||||||
|
(passages_idx_file, "passages_index"),
|
||||||
|
]:
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
sizes[name] = size_mb
|
||||||
|
total_with_embeddings += size_mb
|
||||||
|
|
||||||
|
else:
|
||||||
|
sizes[name] = 0
|
||||||
|
|
||||||
|
sizes["total_with_embeddings"] = total_with_embeddings
|
||||||
|
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||||
|
|
||||||
|
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||||
|
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||||
|
"""Create a compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage
|
||||||
|
**meta.get("backend_kwargs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||||
|
builder.build_index(compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||||
|
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
compact_sizes["index_type"] = "compact"
|
||||||
|
|
||||||
|
return compact_sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||||
|
builder.build_index(non_compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = [item["question"] for item in test_data[:5]]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_timing_breakdown(
|
||||||
|
self, data: list[dict], max_samples: Optional[int] = None
|
||||||
|
) -> dict:
|
||||||
|
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||||
|
if not self.chat or not self.openai_client:
|
||||||
|
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||||
|
return {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0.0,
|
||||||
|
"avg_generation_time": 0.0,
|
||||||
|
"avg_total_time": 0.0,
|
||||||
|
"accuracy": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||||
|
|
||||||
|
if max_samples:
|
||||||
|
data = data[:max_samples]
|
||||||
|
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
generation_times = []
|
||||||
|
total_times = []
|
||||||
|
correct_answers = 0
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
question = item["question"]
|
||||||
|
ground_truth = item["answer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Hack: Monkey-patch the ask method to capture internal timing
|
||||||
|
original_ask = self.chat.ask
|
||||||
|
captured_search_time = None
|
||||||
|
captured_generation_time = None
|
||||||
|
|
||||||
|
def patched_ask(*args, **kwargs):
|
||||||
|
nonlocal captured_search_time, captured_generation_time
|
||||||
|
|
||||||
|
# Time the search part
|
||||||
|
search_start = time.time()
|
||||||
|
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||||
|
captured_search_time = time.time() - search_start
|
||||||
|
|
||||||
|
# Time the generation part
|
||||||
|
context = "\n\n".join([r.text for r in results])
|
||||||
|
prompt = (
|
||||||
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
f"Question: {args[0]}\n\n"
|
||||||
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_start = time.time()
|
||||||
|
answer = self.chat.llm.ask(prompt)
|
||||||
|
captured_generation_time = time.time() - generation_start
|
||||||
|
|
||||||
|
return answer
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
self.chat.ask = patched_ask
|
||||||
|
|
||||||
|
# Time the total QA
|
||||||
|
total_start = time.time()
|
||||||
|
generated_answer = self.chat.ask(question)
|
||||||
|
total_time = time.time() - total_start
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
self.chat.ask = original_ask
|
||||||
|
|
||||||
|
# Store the timings
|
||||||
|
search_times.append(captured_search_time)
|
||||||
|
generation_times.append(captured_generation_time)
|
||||||
|
total_times.append(total_time)
|
||||||
|
|
||||||
|
# Check accuracy using LLM as judge
|
||||||
|
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||||
|
if is_correct:
|
||||||
|
correct_answers += 1
|
||||||
|
|
||||||
|
status = "✅" if is_correct else "❌"
|
||||||
|
print(
|
||||||
|
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||||
|
)
|
||||||
|
print(f" GT: {ground_truth}")
|
||||||
|
print(f" Gen: {generated_answer[:100]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error: {e}")
|
||||||
|
search_times.append(0.0)
|
||||||
|
generation_times.append(0.0)
|
||||||
|
total_times.append(0.0)
|
||||||
|
|
||||||
|
accuracy = correct_answers / len(data) if data else 0.0
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"total_questions": len(data),
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||||
|
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||||
|
if generation_times
|
||||||
|
else 0.0,
|
||||||
|
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"correct_answers": correct_answers,
|
||||||
|
"search_times": search_times,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"total_times": total_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _check_answer_accuracy(
|
||||||
|
self, generated_answer: str, ground_truth: str, question: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||||
|
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Ground Truth Answer: {ground_truth}
|
||||||
|
|
||||||
|
Generated Answer: {generated_answer}
|
||||||
|
|
||||||
|
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||||
|
1. Numerical accuracy (exact values, units, currency)
|
||||||
|
2. Key financial concepts and terminology
|
||||||
|
3. Overall factual correctness
|
||||||
|
|
||||||
|
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||||
|
|
||||||
|
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
judge_response = self.openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||||
|
return judgment == "CORRECT"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||||
|
# Fallback to simple string matching
|
||||||
|
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
return gt_clean in gen_clean
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 EVALUATION RESULTS")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Index comparison analysis
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
|
||||||
|
print("\n📊 Accuracy:")
|
||||||
|
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||||
|
print(
|
||||||
|
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Timing Breakdown:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||||
|
|
||||||
|
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||||
|
search_pct = (
|
||||||
|
timing_metrics.get("avg_search_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
gen_pct = (
|
||||||
|
timing_metrics.get("avg_generation_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print("\n📈 Time Distribution:")
|
||||||
|
print(f" Search: {search_pct:.1f}%")
|
||||||
|
print(f" Generation: {gen_pct:.1f}%")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load FinanceBench queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Test with more queries for robust measurement
|
||||||
|
test_queries = queries[:2000]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = FinanceBenchEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Use more queries for robust measurement
|
||||||
|
test_queries = queries[:200]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity (without recompute for speed)
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 32
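            # Recall@3 is (roughly) non-decreasing in complexity, so binary search converges
            # on the smallest complexity that still clears the 90% recall target.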
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||||
|
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||||
|
|
||||||
|
# Use FinanceBench evaluator for QA evaluation
|
||||||
|
evaluator = FinanceBenchEvaluator(
|
||||||
|
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
data = evaluator.load_dataset(args.dataset)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes
|
||||||
|
print("\n📊 Index size comparison:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print("\n📊 Index-only size comparison (.index file only):")
|
||||||
|
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
# Use index-only size for fair comparison (same as Enron emails)
|
||||||
|
storage_saving = (
|
||||||
|
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||||
|
/ non_compact_size_metrics["index_only_mb"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for performance comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Generation evaluation
|
||||||
|
test_samples = 20
|
||||||
|
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||||
|
|
||||||
|
if args.llm_backend == "openai" and args.openai_api_key:
|
||||||
|
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||||
|
evaluation_start = time.time()
|
||||||
|
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||||
|
evaluation_time = time.time() - evaluation_start
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||||
|
)
|
||||||
|
                evaluation_start = time.time()
                try:
|
||||||
|
# Load LLM
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Simple generation evaluation
|
||||||
|
queries = [item["question"] for item in data[:test_samples]]
|
||||||
|
gen_results = evaluate_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
llm_func,
|
||||||
|
queries,
|
||||||
|
domain="finance",
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": len(queries),
|
||||||
|
"avg_search_time": gen_results["avg_search_time"],
|
||||||
|
"avg_generation_time": gen_results["avg_generation_time"],
|
||||||
|
"results": gen_results["results"],
|
||||||
|
}
|
||||||
|
                    evaluation_time = time.time() - evaluation_start
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0,
|
||||||
|
"avg_generation_time": 0,
|
||||||
|
}
|
||||||
|
evaluation_time = 0
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
**timing_metrics,
|
||||||
|
"total_evaluation_time": evaluation_time,
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||||
|
print(" - Run full evaluation on complete dataset if needed")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
benchmarks/financebench/setup_financebench.py (new executable file, 462 lines)
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FinanceBench Complete Setup Script
|
||||||
|
Downloads all PDFs and builds full LEANN datastore
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import pymupdf
|
||||||
|
import requests
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||||
|
self.data_dir = self.base_dir / data_dir
|
||||||
|
self.pdf_dir = self.data_dir / "pdfs"
|
||||||
|
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||||
|
self.index_dir = self.data_dir / "index"
|
||||||
|
self.download_lock = Lock()
|
||||||
|
|
||||||
|
def download_dataset(self):
|
||||||
|
"""Download the main FinanceBench dataset"""
|
||||||
|
print("📊 Downloading FinanceBench dataset...")
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if self.dataset_file.exists():
|
||||||
|
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(self.dataset_file, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||||
|
|
||||||
|
def get_pdf_list(self):
|
||||||
|
"""Get list of all PDF files from GitHub"""
|
||||||
|
print("📋 Fetching PDF list from GitHub...")
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
pdf_files = response.json()
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files")
|
||||||
|
return pdf_files
|
||||||
|
|
||||||
|
def download_single_pdf(self, pdf_info, position):
|
||||||
|
"""Download a single PDF file"""
|
||||||
|
pdf_name = pdf_info["name"]
|
||||||
|
pdf_path = self.pdf_dir / pdf_name
|
||||||
|
|
||||||
|
# Skip if already downloaded
|
||||||
|
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||||
|
return f"✅ {pdf_name} (cached)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download PDF
|
||||||
|
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Write to file
|
||||||
|
with self.download_lock:
|
||||||
|
with open(pdf_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ {pdf_name}: {e!s}"
|
||||||
|
|
||||||
|
def download_all_pdfs(self, max_workers: int = 5):
|
||||||
|
"""Download all PDF files with parallel processing"""
|
||||||
|
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pdf_files = self.get_pdf_list()
|
||||||
|
|
||||||
|
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all download tasks
|
||||||
|
future_to_pdf = {
|
||||||
|
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||||
|
for i, pdf_info in enumerate(pdf_files)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process completed downloads with progress bar
|
||||||
|
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||||
|
for future in as_completed(future_to_pdf):
|
||||||
|
result = future.result()
|
||||||
|
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Verify downloads
|
||||||
|
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||||
|
|
||||||
|
# Show any failures
|
||||||
|
missing_pdfs = []
|
||||||
|
for pdf_info in pdf_files:
|
||||||
|
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||||
|
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||||
|
missing_pdfs.append(pdf_info["name"])
|
||||||
|
|
||||||
|
if missing_pdfs:
|
||||||
|
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||||
|
for pdf in missing_pdfs[:5]: # Show first 5
|
||||||
|
print(f" - {pdf}")
|
||||||
|
if len(missing_pdfs) > 5:
|
||||||
|
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||||
|
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
):
|
||||||
|
"""Build LEANN index from all PDFs"""
|
||||||
|
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
# Check if we have PDFs
|
||||||
|
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
if not pdf_files:
|
||||||
|
raise RuntimeError("No PDF files found! Run download first.")
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files to process")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Initialize builder with standard compact configuration
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage (pruned)
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process PDFs and extract text
|
||||||
|
total_chunks = 0
|
||||||
|
failed_pdfs = []
|
||||||
|
|
||||||
|
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||||
|
try:
|
||||||
|
chunks = self.extract_pdf_text(pdf_path)
|
||||||
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
|
total_chunks += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||||
|
failed_pdfs.append(pdf_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build index in index directory
|
||||||
|
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||||
|
print(f"🔨 Building index: {index_path}")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
print("✅ Index built successfully!")
|
||||||
|
print(f" 📁 Index path: {index_path}")
|
||||||
|
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||||
|
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||||
|
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||||
|
|
||||||
|
if failed_pdfs:
|
||||||
|
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||||
|
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||||
|
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read metadata from the built index
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.loads(f.read())
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
print(f"📊 Loading passages from {passage_file}...")
|
||||||
|
print(f"🤖 Using embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
# Load all passages for baseline
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages")
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
print("🧮 Computing embeddings...")
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
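        # The faiss module from leann_backend_hnsw exposes the low-level SWIG call signature,
        # so add() takes an element count plus a raw pointer instead of a NumPy array.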
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(passage_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||||
|
"""Extract and chunk text from a PDF file"""
|
||||||
|
chunks = []
|
||||||
|
doc = pymupdf.open(pdf_path)
|
||||||
|
|
||||||
|
for page_num in range(len(doc)):
|
||||||
|
page = doc[page_num]
|
||||||
|
text = page.get_text() # type: ignore
|
||||||
|
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"source_file": pdf_path.name,
|
||||||
|
"page_number": page_num + 1,
|
||||||
|
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||||
|
"company": pdf_path.name.split("_")[0],
|
||||||
|
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||||
|
}
|
||||||
|
|
||||||
|
            # Split long pages into ~300-word chunks on paragraph boundaries
            # (paragraph-first splitting, similar in spirit to LangChain's recursive splitter)
|
||||||
|
if len(text.split()) > 500:
|
||||||
|
# Split by double newlines (paragraphs)
|
||||||
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
|
|
||||||
|
current_chunk = ""
|
||||||
|
for para in paragraphs:
|
||||||
|
# If adding this paragraph would make chunk too long, save current chunk
|
||||||
|
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
current_chunk = para
|
||||||
|
else:
|
||||||
|
current_chunk = (current_chunk + " " + para).strip()
|
||||||
|
|
||||||
|
# Add the last chunk
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Page is short enough, use as single chunk
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": text.strip(),
|
||||||
|
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_year_from_filename(self, filename: str) -> str:
|
||||||
|
"""Extract year from PDF filename"""
|
||||||
|
# Try to find 4-digit year in filename
|
||||||
|
|
||||||
|
match = re.search(r"(\d{4})", filename)
|
||||||
|
return match.group(1) if match else "unknown"
|
||||||
|
|
||||||
|
def verify_setup(self, index_path: str):
|
||||||
|
"""Verify the setup by testing a simple query"""
|
||||||
|
print("🧪 Verifying setup with test query...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test query
|
||||||
|
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||||
|
results = searcher.search(test_query, top_k=3)
|
||||||
|
|
||||||
|
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
company = result.metadata.get("company", "Unknown")
|
||||||
|
year = result.metadata.get("doc_period", "Unknown")
|
||||||
|
page = result.metadata.get("page_number", "Unknown")
|
||||||
|
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:100]}...")
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
print("✅ Setup verification completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Setup verification failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
parser.add_argument(
|
||||||
|
"--build-baseline-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only build FAISS baseline from existing index",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🏦 FinanceBench Complete Setup")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
setup = FinanceBenchSetup(args.data_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.build_baseline_only:
|
||||||
|
# Only build baseline from existing index
|
||||||
|
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||||
|
index_file = f"{index_path}.index"
|
||||||
|
meta_file = f"{index_path}.leann.meta.json"
|
||||||
|
|
||||||
|
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||||
|
print("❌ Index files not found:")
|
||||||
|
print(f" Index: {index_file}")
|
||||||
|
print(f" Meta: {meta_file}")
|
||||||
|
print("💡 Run without --build-baseline-only to build the index first")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Download dataset
|
||||||
|
setup.download_dataset()
|
||||||
|
|
||||||
|
# Step 2: Download PDFs
|
||||||
|
if not args.skip_download:
|
||||||
|
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping PDF download")
|
||||||
|
|
||||||
|
# Step 3: Build LEANN index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
backend=args.backend, embedding_model=args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Building FAISS flat baseline...")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
|
||||||
|
# Step 5: Verify setup
|
||||||
|
setup.verify_setup(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
|
||||||
|
print("\n🎉 FinanceBench setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print(
|
||||||
|
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
benchmarks/financebench/verify_recall.py (new file, 214 lines)
#!/usr/bin/env python3
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.9"
|
||||||
|
# dependencies = [
|
||||||
|
# "faiss-cpu",
|
||||||
|
# "numpy",
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
"""
|
||||||
|
Independent recall verification script using standard FAISS.
|
||||||
|
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Direct embedding computation using sentence-transformers.
|
||||||
|
Copied logic to avoid dependency issues.
|
||||||
|
"""
|
||||||
|
print(f"Loading model: {model_name}")
|
||||||
|
model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
|
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||||
|
embeddings = model.encode(
|
||||||
|
chunks,
|
||||||
|
show_progress_bar=True,
|
||||||
|
batch_size=32,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||||
|
"""Load FinanceBench queries from dataset"""
|
||||||
|
queries = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
if len(queries) >= max_queries:
|
||||||
|
break
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||||
|
"""Load passages from LEANN index structure"""
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
passage_file = index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"Loading passages from {passage_file}")
|
||||||
|
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Loading passages"):
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages)} passages")
|
||||||
|
return passages, passage_ids
|
||||||
|
|
||||||
|
|
||||||
|
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||||
|
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
|
||||||
|
# Build Flat index (ground truth)
|
||||||
|
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||||
|
flat_index = faiss.IndexFlatIP(dimension)
|
||||||
|
flat_index.add(embeddings)
|
||||||
|
|
||||||
|
# Build HNSW index
|
||||||
|
print("Building FAISS IndexHNSWFlat...")
|
||||||
|
M = 32 # Same as LEANN default
|
||||||
|
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||||
|
hnsw_index.add(embeddings)
|
||||||
|
|
||||||
|
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||||
|
return flat_index, hnsw_index
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_recall_at_k(
|
||||||
|
query_embeddings: np.ndarray,
|
||||||
|
flat_index: faiss.Index,
|
||||||
|
hnsw_index: faiss.Index,
|
||||||
|
passage_ids: list[str],
|
||||||
|
k: int = 3,
|
||||||
|
ef_search: int = 64,
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||||
|
|
||||||
|
# Set search parameters for HNSW
|
||||||
|
hnsw_index.hnsw.efSearch = ef_search
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = query_embeddings.shape[0]
|
||||||
|
|
||||||
|
for i in range(num_queries):
|
||||||
|
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||||
|
|
||||||
|
# Get ground truth from Flat index (standard FAISS API)
|
||||||
|
flat_distances, flat_indices = flat_index.search(query, k)
|
||||||
|
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||||
|
|
||||||
|
# Get results from HNSW index (standard FAISS API)
|
||||||
|
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||||
|
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||||
|
|
||||||
|
# Calculate recall
|
||||||
|
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||||
|
recall = len(intersection) / k
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||||
|
print(f" Flat: {list(ground_truth_ids)}")
|
||||||
|
print(f" HNSW: {list(hnsw_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configuration
|
||||||
|
dataset_path = "data/financebench_merged.jsonl"
|
||||||
|
index_path = "data/index/financebench_full_hnsw.leann"
|
||||||
|
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
|
||||||
|
print("🔍 FAISS Recall Verification")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check if files exist
|
||||||
|
if not Path(dataset_path).exists():
|
||||||
|
print(f"❌ Dataset not found: {dataset_path}")
|
||||||
|
return
|
||||||
|
if not Path(f"{index_path}.meta.json").exists():
|
||||||
|
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
print("📖 Loading FinanceBench queries...")
|
||||||
|
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||||
|
print(f"Loaded {len(queries)} queries")
|
||||||
|
|
||||||
|
print("📄 Loading passages from LEANN index...")
|
||||||
|
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
print("🧮 Computing passage embeddings...")
|
||||||
|
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||||
|
|
||||||
|
print("🧮 Computing query embeddings...")
|
||||||
|
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||||
|
|
||||||
|
# Build FAISS indexes
|
||||||
|
print("🏗️ Building FAISS indexes...")
|
||||||
|
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||||
|
|
||||||
|
# Test different efSearch values (equivalent to LEANN complexity)
|
||||||
|
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||||
|
ef_search_values = [16, 32, 64, 128, 256]
|
||||||
|
|
||||||
|
for ef_search in ef_search_values:
|
||||||
|
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
recall = evaluate_recall_at_k(
|
||||||
|
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Verification completed!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||||
|
print(" - Compared recall@3 at different efSearch values")
|
||||||
|
print(" - Used same embedding model as LEANN")
|
||||||
|
print(" - This validates LEANN's recall measurements")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
benchmarks/laion/.gitignore (new vendored file, 1 line)
data/
benchmarks/laion/README.md (new file, 199 lines)
# LAION Multimodal Benchmark
|
||||||
|
|
||||||
|
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This benchmark evaluates:
|
||||||
|
- **Image retrieval timing** using caption-based queries
|
||||||
|
- **Recall@K performance** for image search
|
||||||
|
- **Complexity analysis** across different search parameters
|
||||||
|
- **Index size and storage efficiency**
|
||||||
|
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||||
|
|
||||||
|
## Dataset Configuration
|
||||||
|
|
||||||
|
- **Dataset**: LAION-400M subset (10,000 images)
|
||||||
|
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||||
|
- **Queries**: 200 random captions from the dataset
|
||||||
|
- **Ground Truth**: Self-recall (query caption → original image)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Setup the benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/laion
|
||||||
|
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Create dummy LAION data (10K samples)
|
||||||
|
- Generate CLIP embeddings (512-dim)
|
||||||
|
- Build LEANN index with HNSW backend
|
||||||
|
- Create 200 evaluation queries
|
||||||
|
|
||||||
|
### 2. Run evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all evaluation stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann
|
||||||
|
|
||||||
|
# Run specific stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||||
|
|
||||||
|
# Multimodal generation with Qwen2.5-VL
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Save results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Setup Options
|
||||||
|
```bash
|
||||||
|
python setup_laion.py \
|
||||||
|
--num-samples 10000 \
|
||||||
|
--num-queries 200 \
|
||||||
|
--index-path data/laion_index.leann \
|
||||||
|
--backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation Options
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py \
|
||||||
|
--index data/laion_index.leann \
|
||||||
|
--queries data/evaluation_queries.jsonl \
|
||||||
|
--complexity 64 \
|
||||||
|
--top-k 3 \
|
||||||
|
--num-samples 100 \
|
||||||
|
--stage all
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Stages
|
||||||
|
|
||||||
|
### Stage 2: Recall Evaluation
|
||||||
|
- Evaluates Recall@3 for multimodal retrieval
|
||||||
|
- Compares LEANN vs FAISS baseline performance
|
||||||
|
- Self-recall: the query caption should retrieve the image it was originally paired with (see the sketch below)
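
A minimal sketch of the self-recall check, assuming each evaluation query records the `image_id` of the image its caption came from and that search results expose a matching `id` field (names are illustrative, not necessarily the exact attributes used by the script):

```python
def recall_at_3(searcher, queries, complexity=64):
    """queries: [{"caption": str, "image_id": str}, ...]; image_id is the ground truth."""
    hits = 0
    for q in queries:
        results = searcher.search(q["caption"], top_k=3, complexity=complexity)
        if q["image_id"] in {r.id for r in results}:
            hits += 1
    return hits / len(queries) if queries else 0.0
```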
|
||||||
|
|
||||||
|
### Stage 3: Complexity Analysis
|
||||||
|
- Binary search for optimal complexity (90% recall target; see the sketch after this list)
|
||||||
|
- Tests performance across different complexity levels
|
||||||
|
- Analyzes speed vs. accuracy tradeoffs
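
A rough sketch of the binary search, mirroring the evaluation scripts in this repo and assuming `eval_recall(complexity)` returns Recall@3 over a fixed query set:

```python
def find_min_complexity(eval_recall, target=0.9, lo=1, hi=128):
    """Smallest complexity whose Recall@3 meets the target, or None if none does."""
    best = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if eval_recall(mid) >= target:
            best, hi = mid, mid - 1  # target met: try a cheaper setting
        else:
            lo = mid + 1  # below target: search higher complexities
    return best
```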
|
||||||
|
|
||||||
|
### Stage 4: Index Comparison
|
||||||
|
- Compares compact vs non-compact index sizes
|
||||||
|
- Measures search performance differences
|
||||||
|
- Reports storage efficiency and speed ratios
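
Storage saving and the speed ratio are reported the same way as in the FinanceBench benchmark, from the `.index` file sizes and average search times (variable names below are illustrative):

```python
storage_saving_percent = (non_compact_index_mb - compact_index_mb) / non_compact_index_mb * 100
speed_ratio = non_compact_search_time / compact_search_time  # < 1 means recompute is slower
```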
|
||||||
|
|
||||||
|
### Stage 5: Multimodal Generation
|
||||||
|
- Uses Qwen2.5-VL for image understanding and description
|
||||||
|
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||||
|
- Measures both search and generation timing
|
||||||
|
|
||||||
|
## Output Metrics
|
||||||
|
|
||||||
|
### Timing Metrics
|
||||||
|
- Average/median/min/max search time
|
||||||
|
- Standard deviation
|
||||||
|
- Searches per second
|
||||||
|
- Latency in milliseconds
|
||||||
|
|
||||||
|
### Recall Metrics
|
||||||
|
- Recall@3 percentage for image retrieval
|
||||||
|
- Number of queries with ground truth
|
||||||
|
|
||||||
|
### Index Metrics
|
||||||
|
- Total index size (MB)
|
||||||
|
- Component breakdown (index, passages, metadata)
|
||||||
|
- Storage savings (compact vs non-compact)
|
||||||
|
- Backend and embedding model info
|
||||||
|
|
||||||
|
### Generation Metrics (Stage 5)
|
||||||
|
- Average search time per query
|
||||||
|
- Average generation time per query
|
||||||
|
- Time distribution (search vs generation)
|
||||||
|
- Sample multimodal responses
|
||||||
|
- Qwen2.5-VL generation performance
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||||
|
|
||||||
|
**Stage 3: Optimal Complexity Analysis**
|
||||||
|
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||||
|
- **Binary Search Range**: 1-128
|
||||||
|
- **Target Recall**: 90%
|
||||||
|
- **Index Type**: Non-compact (for fast binary search)
|
||||||
|
|
||||||
|
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||||
|
- **Total Queries**: 20
|
||||||
|
- **Average Search Time**: 1.200s per query
|
||||||
|
- **Average Generation Time**: 6.558s per query
|
||||||
|
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
- **Optimal Complexity**: 85
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||||
|
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||||
|
- **Backend**: HNSW with cosine distance
|
||||||
|
|
||||||
|
### Example Results
|
||||||
|
|
||||||
|
```
|
||||||
|
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
📊 Multimodal Generation Results:
|
||||||
|
Total Queries: 20
|
||||||
|
Avg Search Time: 1.200s
|
||||||
|
Avg Generation Time: 6.558s
|
||||||
|
Time Distribution: Search 15.5%, Generation 84.5%
|
||||||
|
LLM Backend: HuggingFace transformers
|
||||||
|
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
⚙️ Optimal Complexity Analysis:
|
||||||
|
Target Recall: 90%
|
||||||
|
Optimal Complexity: 85
|
||||||
|
Binary Search Range: 1-128
|
||||||
|
Non-compact Index (fast search, no recompute)
|
||||||
|
|
||||||
|
🚀 Performance Summary:
|
||||||
|
Multimodal RAG: 7.758s total per query
|
||||||
|
Search: 15.5% of total time
|
||||||
|
Generation: 84.5% of total time
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure

```
benchmarks/laion/
├── setup_laion.py               # Setup script
├── evaluate_laion.py            # Evaluation script
├── README.md                    # This file
└── data/                        # Generated data
    ├── laion_images/            # Image files (placeholder)
    ├── laion_metadata.jsonl     # Image metadata
    ├── laion_passages.jsonl     # LEANN passages
    ├── laion_embeddings.npy     # CLIP embeddings
    ├── evaluation_queries.jsonl # Evaluation queries
    └── laion_index.leann/       # LEANN index files
```

## Notes

- Current implementation uses dummy data for demonstration
- For real LAION data, implement actual download logic in `setup_laion.py`
- CLIP embeddings are randomly generated - replace with real CLIP model for production
- Adjust `num_samples` and `num_queries` based on available resources
- Consider using `--num-samples` during evaluation for faster testing

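## Quick Search Example

The index produced by `setup_laion.py` can also be queried directly with the `leann` API used throughout `evaluate_laion.py`. A minimal sketch; the paths and the complexity value (85, the Stage 3 optimum reported above) are assumptions based on this benchmark's defaults and may need adjusting:

```python
import json

from leann import LeannSearcher

# Paths assume the data layout produced by setup_laion.py.
searcher = LeannSearcher("data/laion_index.leann")

with open("data/evaluation_queries.jsonl", encoding="utf-8") as f:
    captions = [json.loads(line)["query"] for line in f if line.strip()]

# Query the compact index; it is pruned, so embeddings are recomputed at search time.
for caption in captions[:3]:
    results = searcher.search(caption, top_k=3, complexity=85, recompute_embeddings=True)
    print(caption[:50], "->", [r.id for r in results])

searcher.cleanup()
```
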
725 benchmarks/laion/evaluate_laion.py Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline (image embeddings)
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.image_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||||
|
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(captions)
|
||||||
|
|
||||||
|
for i, caption in enumerate(captions):
|
||||||
|
# Get ground truth: search with FAISS flat using caption text embedding
|
||||||
|
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||||
|
query_embedding = self.st_clip.encode(
|
||||||
|
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results (image IDs from FAISS)
|
||||||
|
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity (using caption as text query)
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
caption,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
"""Load caption queries from evaluation file"""
|
||||||
|
captions = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
query_data = json.loads(line)
|
||||||
|
captions.append(query_data["query"])
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(captions)} caption queries")
|
||||||
|
return captions
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
# Core index size (.index only)
|
||||||
|
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
sizes["index_only_mb"] = index_mb
|
||||||
|
|
||||||
|
# Other files for reference (not counted in index_only_mb)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||||
|
if sizes["metadata_mb"]:
|
||||||
|
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||||
|
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||||
|
print(
|
||||||
|
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Load CLIP embeddings
|
||||||
|
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
distance_metric="cosine",
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare ids and add passages
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
# Ensure metadata contains the id used by the vector index
|
||||||
|
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||||
|
builder.add_text(text=data["text"], metadata=metadata)
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = test_captions[:5]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Index comparison analysis (prefer .index-only view if present)
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||||
|
print("\n📏 Index Comparison Analysis (.index only):")
|
||||||
|
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
# Show excluded components for reference if available
|
||||||
|
if any(
|
||||||
|
k in non_compact
|
||||||
|
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||||
|
):
|
||||||
|
print(" (passages excluded in totals, shown for reference):")
|
||||||
|
print(
|
||||||
|
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||||
|
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||||
|
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to legacy totals if running with older metrics
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(
|
||||||
|
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend",
|
||||||
|
choices=["hf"],
|
||||||
|
default="hf",
|
||||||
|
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_laion.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
laion_evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = laion_evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Test with queries for robust measurement
|
||||||
|
test_captions = captions[:100] # Use subset for speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Use subset for robust measurement
|
||||||
|
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 128
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Index comparison (without LLM generation)
|
||||||
|
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||||
|
|
||||||
|
# Use LAION evaluator for index comparison
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
|
||||||
|
# Load caption queries
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes (.index only)
|
||||||
|
print("\n📊 Index size comparison (.index only):")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||||
|
|
||||||
|
storage_saving = 0.0
|
||||||
|
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||||
|
storage_saving = (
|
||||||
|
(
|
||||||
|
non_compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
- compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
)
|
||||||
|
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for index comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print comprehensive results
|
||||||
|
evaluator._print_results(combined_metrics)
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||||
|
|
||||||
|
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||||
|
|
||||||
|
# Load Qwen2.5-VL model
|
||||||
|
try:
|
||||||
|
print("Loading Qwen2.5-VL model...")
|
||||||
|
processor, model = load_qwen_vl_model(args.model_name)
|
||||||
|
|
||||||
|
# Run multimodal generation evaluation
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
gen_results = evaluate_multimodal_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
test_captions,
|
||||||
|
processor=processor,
|
||||||
|
model=model,
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Multimodal Generation Results:")
|
||||||
|
print(f" Total Queries: {len(test_captions)}")
|
||||||
|
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||||
|
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||||
|
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||||
|
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||||
|
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||||
|
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||||
|
print(" LLM Backend: HuggingFace transformers")
|
||||||
|
print(f" Model: {args.model_name}")
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Multimodal Generations:")
|
||||||
|
for i, response in enumerate(gen_results["results"][:3]):
|
||||||
|
# Handle both string and dict formats for captions
|
||||||
|
if isinstance(test_captions[i], dict):
|
||||||
|
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||||
|
else:
|
||||||
|
caption_text = str(test_captions[i])
|
||||||
|
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||||
|
print(f" Response {i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||||
|
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
576 benchmarks/laion/setup_laion.py Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Setup Script
|
||||||
|
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann import LeannBuilder
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.images_dir = self.data_dir / "laion_images"
|
||||||
|
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||||
|
|
||||||
|
# Create directories
|
||||||
|
self.data_dir.mkdir(exist_ok=True)
|
||||||
|
self.images_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||||
|
"""Download a single image asynchronously"""
|
||||||
|
async with semaphore: # Limit concurrent downloads
|
||||||
|
try:
|
||||||
|
image_url = sample_data["url"]
|
||||||
|
image_path = sample_data["image_path"]
|
||||||
|
|
||||||
|
# Skip if already exists
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
|
||||||
|
async with session.get(image_url, timeout=10) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# Verify it's a valid image
|
||||||
|
try:
|
||||||
|
img = Image.open(io.BytesIO(content))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.save(image_path, "JPEG")
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None # Skip invalid images
|
||||||
|
else:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def download_laion_subset(self, num_samples: int = 1000):
|
||||||
|
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||||
|
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||||
|
|
||||||
|
# Load LAION-400M subset from HuggingFace
|
||||||
|
print("🤗 Loading from HuggingFace datasets...")
|
||||||
|
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||||
|
|
||||||
|
# Collect sample metadata first (fast)
|
||||||
|
print("📋 Collecting sample metadata...")
|
||||||
|
candidates = []
|
||||||
|
for sample in dataset:
|
||||||
|
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||||
|
break
|
||||||
|
|
||||||
|
image_url = sample.get("url", "")
|
||||||
|
caption = sample.get("caption", "")
|
||||||
|
|
||||||
|
if not image_url or not caption:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||||
|
image_path = self.images_dir / image_filename
|
||||||
|
|
||||||
|
candidate = {
|
||||||
|
"id": f"laion_{len(candidates):06d}",
|
||||||
|
"url": image_url,
|
||||||
|
"caption": caption,
|
||||||
|
"image_path": str(image_path),
|
||||||
|
"width": sample.get("original_width", 512),
|
||||||
|
"height": sample.get("original_height", 512),
|
||||||
|
"similarity": sample.get("similarity", 0.0),
|
||||||
|
}
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download images in parallel
|
||||||
|
async def download_batch():
|
||||||
|
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||||
|
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
|
|
||||||
|
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||||
|
tasks = []
|
||||||
|
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||||
|
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all downloads
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
# Filter successful downloads
|
||||||
|
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||||
|
return successful[:num_samples]
|
||||||
|
|
||||||
|
# Run async download
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
samples = loop.run_until_complete(download_batch())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
# Save metadata
|
||||||
|
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in samples:
|
||||||
|
f.write(json.dumps(sample) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||||
|
"""Generate CLIP image embeddings for downloaded images"""
|
||||||
|
print("🔍 Generating CLIP image embeddings...")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||||
|
# This single model can encode both images and text.
|
||||||
|
model = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
valid_samples = []
|
||||||
|
|
||||||
|
for sample in tqdm(samples, desc="Processing images"):
|
||||||
|
try:
|
||||||
|
# Load image
|
||||||
|
image_path = sample["image_path"]
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||||
|
vec = model.encode(
|
||||||
|
[image],
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=1,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)[0]
|
||||||
|
embeddings.append(vec.astype(np.float32))
|
||||||
|
valid_samples.append(sample)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||||
|
# Skip invalid images
|
||||||
|
|
||||||
|
embeddings = np.array(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Save embeddings
|
||||||
|
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(embeddings_file, embeddings)
|
||||||
|
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
return embeddings, valid_samples
|
||||||
|
|
||||||
|
def build_faiss_baseline(
|
||||||
|
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||||
|
):
|
||||||
|
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Extract image IDs (must be present)
|
||||||
|
if not samples or "id" not in samples[0]:
|
||||||
|
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||||
|
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
print(f"📄 Processing {len(image_ids)} images")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(image_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def create_leann_passages(self, samples: list[dict]):
|
||||||
|
"""Create LEANN-compatible passages from LAION data"""
|
||||||
|
print("📝 Creating LEANN passages...")
|
||||||
|
|
||||||
|
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||||
|
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
passage = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"text": sample["caption"], # Use caption as searchable text
|
||||||
|
"metadata": {
|
||||||
|
"image_url": sample["url"],
|
||||||
|
"image_path": sample.get("image_path", ""),
|
||||||
|
"width": sample["width"],
|
||||||
|
"height": sample["height"],
|
||||||
|
"similarity": sample["similarity"],
|
||||||
|
"image_index": i, # Index for embedding lookup
|
||||||
|
},
|
||||||
|
}
|
||||||
|
f.write(json.dumps(passage) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(samples)} passages")
|
||||||
|
return passages_file
|
||||||
|
|
||||||
|
def build_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||||
|
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
|
||||||
|
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - compact with recompute
|
||||||
|
# Note: For multimodal case, we need to handle embeddings differently
|
||||||
|
# Let's try using sentence-transformers mode but with custom embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
# HNSW params (or forwarded to chosen backend)
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
# Compact/pruned with recompute at query time
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages (text + metadata)
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def build_non_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||||
|
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure embeddings are saved (npy + pickle)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
if not npy_path.exists():
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
# Prepare ids in same order as passages_file
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
if not pkl_path.exists():
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - non-compact without recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=False, # Store embeddings (no recompute needed)
|
||||||
|
is_compact=False, # Store full index (not pruned)
|
||||||
|
distance_metric="cosine",
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages - embeddings will be loaded from file
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||||
|
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Adding passages"):
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
|
||||||
|
# Add image metadata - LEANN will handle embeddings separately
|
||||||
|
# Note: We store image metadata and caption text for searchability
|
||||||
|
# Important: ensure passage ID in metadata matches vector ID
|
||||||
|
builder.add_text(
|
||||||
|
text=passage["text"], # Image caption for searchability
|
||||||
|
metadata={**passage["metadata"], "id": passage["id"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_index_size(self, index_path: str):
|
||||||
|
"""Analyze index file sizes"""
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
index_path = Path(index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.name # e.g., laion_index.leann
|
||||||
|
index_prefix = index_path.stem # e.g., laion_index
|
||||||
|
|
||||||
|
files = [
|
||||||
|
(f"{index_prefix}.index", ".index", "core"),
|
||||||
|
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||||
|
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||||
|
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||||
|
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _fmt_size(bytes_val: int) -> str:
|
||||||
|
if bytes_val < 1024:
|
||||||
|
return f"{bytes_val} B"
|
||||||
|
kb = bytes_val / 1024
|
||||||
|
if kb < 1024:
|
||||||
|
return f"{kb:.1f} KB"
|
||||||
|
mb = kb / 1024
|
||||||
|
if mb < 1024:
|
||||||
|
return f"{mb:.2f} MB"
|
||||||
|
gb = mb / 1024
|
||||||
|
return f"{gb:.2f} GB"
|
||||||
|
|
||||||
|
total_index_only_mb = 0.0
|
||||||
|
total_all_mb = 0.0
|
||||||
|
for filename, label, group in files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_bytes = file_path.stat().st_size
|
||||||
|
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||||
|
size_mb = size_bytes / (1024 * 1024)
|
||||||
|
total_all_mb += size_mb
|
||||||
|
if group == "core":
|
||||||
|
total_index_only_mb += size_mb
|
||||||
|
else:
|
||||||
|
print(f" {label}: (missing)")
|
||||||
|
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||||
|
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||||
|
|
||||||
|
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||||
|
"""Create evaluation queries from captions"""
|
||||||
|
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||||
|
|
||||||
|
# Sample random captions as queries
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(42) # For reproducibility
|
||||||
|
|
||||||
|
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||||
|
|
||||||
|
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
with open(queries_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in query_samples:
|
||||||
|
query = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"query": sample["caption"],
|
||||||
|
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||||
|
}
|
||||||
|
f.write(json.dumps(query) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||||
|
return queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||||
|
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||||
|
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize setup
|
||||||
|
setup = LAIONSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Step 1: Download LAION subset
|
||||||
|
if not args.skip_download:
|
||||||
|
print("\n📦 Step 1: Download LAION subset")
|
||||||
|
samples = setup.download_laion_subset(args.num_samples)
|
||||||
|
|
||||||
|
# Step 2: Generate CLIP image embeddings
|
||||||
|
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||||
|
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||||
|
|
||||||
|
            # Step 3: Create LEANN passages (image metadata with embeddings)
            print("\n📝 Step 3: Create LEANN passages")
            passages_file = setup.create_leann_passages(valid_samples)
        else:
            print("⏭️ Skipping LAION dataset download")
            # Load existing data
            passages_file = setup.data_dir / "laion_passages.jsonl"
            embeddings_file = setup.data_dir / "clip_image_embeddings.npy"

            if not passages_file.exists() or not embeddings_file.exists():
                raise FileNotFoundError(
                    "Passages or embeddings file not found. Run without --skip-download first."
                )

            embeddings = np.load(embeddings_file)
            print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")

        # Step 4: Build LEANN indexes (both compact and non-compact)
        if not args.skip_build:
            print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")

            # Build compact index (production mode - small, recompute required)
            compact_index_path = args.index_path
            print(f"Building compact index: {compact_index_path}")
            setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)

            # Build non-compact index (comparison mode - large, fast search)
            non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
            print(f"Building non-compact index: {non_compact_index_path}")
            setup.build_non_compact_index(
                passages_file, embeddings, non_compact_index_path, args.backend
            )

            # Step 5: Build FAISS flat baseline
            print("\n🔨 Step 5: Build FAISS flat baseline")
            if not args.skip_download:
                baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
            else:
                # Load valid_samples from passages file for FAISS baseline
                valid_samples = []
                with open(passages_file, encoding="utf-8") as f:
                    for line in f:
                        if line.strip():
                            passage = json.loads(line)
                            valid_samples.append({"id": passage["id"], "caption": passage["text"]})
                baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)

            # Step 6: Create evaluation queries
            print("\n📝 Step 6: Create evaluation queries")
            queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
        else:
            print("⏭️ Skipping index building")
            baseline_path = "data/baseline/faiss_index.bin"
            queries_file = setup.data_dir / "evaluation_queries.jsonl"

        print("\n🎉 Setup completed successfully!")
        print("📊 Summary:")
        if not args.skip_download:
            print(f" Downloaded samples: {len(samples)}")
            print(f" Valid samples with embeddings: {len(valid_samples)}")
        else:
            print(f" Loaded {len(embeddings)} embeddings")

        if not args.skip_build:
            print(f" Compact index: {compact_index_path}")
            print(f" Non-compact index: {non_compact_index_path}")
            print(f" FAISS baseline: {baseline_path}")
            print(f" Queries: {queries_file}")

            print("\n🔧 Next steps:")
            print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
            print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
        else:
            print(" Skipped building indexes")

    except KeyboardInterrupt:
        print("\n⚠️ Setup interrupted by user")
        exit(1)
    except Exception as e:
        print(f"\n❌ Setup failed: {e}")
        exit(1)


if __name__ == "__main__":
    main()
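A side note on Step 5 above: a flat FAISS baseline over CLIP image embeddings is typically nothing more than an exact inner-product index. The sketch below is a minimal illustration of that idea, not the actual build_faiss_baseline implementation; the file paths are placeholders and faiss-cpu is assumed to be installed.

import faiss
import numpy as np

# Minimal sketch of an exact (flat) baseline over CLIP embeddings.
# Paths are illustrative placeholders, not taken from the setup script.
embeddings = np.load("data/clip_image_embeddings.npy").astype("float32")
faiss.normalize_L2(embeddings)                 # cosine similarity via inner product
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, "data/baseline/faiss_index.bin")

# Sanity check: exact top-3 neighbours of the first image embedding.
scores, ids = index.search(embeddings[:1], 3)
print(ids[0], scores[0])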
benchmarks/llm_utils.py (new file, 301 lines)
@@ -0,0 +1,301 @@
"""
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
"""

import time

try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

try:
    from vllm import LLM, SamplingParams

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False


def is_qwen3_model(model_name):
    """Check if model is Qwen3"""
    return "Qwen3" in model_name or "qwen3" in model_name.lower()


def is_qwen_vl_model(model_name):
    """Check if model is Qwen2.5-VL"""
    return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()


def apply_qwen3_chat_template(tokenizer, prompt):
    """Apply Qwen3 chat template with thinking enabled"""
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )


def extract_thinking_answer(response):
    """Extract final answer from Qwen3 thinking model response"""
    if "<think>" in response and "</think>" in response:
        try:
            think_end = response.index("</think>") + len("</think>")
            final_answer = response[think_end:].strip()
            return final_answer
        except (ValueError, IndexError):
            pass

    return response.strip()


def load_hf_model(model_name="Qwen/Qwen3-8B"):
    """Load HuggingFace model"""
    if not HF_AVAILABLE:
        raise ImportError("transformers not available")

    print(f"Loading HF: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )
    return tokenizer, model


def load_vllm_model(model_name="Qwen/Qwen3-8B"):
    """Load vLLM model"""
    if not VLLM_AVAILABLE:
        raise ImportError("vllm not available")

    print(f"Loading vLLM: {model_name}")
    llm = LLM(model=model_name, trust_remote_code=True)

    # Qwen3 specific config
    if is_qwen3_model(model_name):
        stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
        max_tokens = 2048
    else:
        stop_tokens = None
        max_tokens = 1024

    sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
    return llm, sampling_params


def generate_hf(tokenizer, model, prompt, max_tokens=None):
    """Generate with HF - supports Qwen3 thinking models"""
    model_name = getattr(model, "name_or_path", "unknown")
    is_qwen3 = is_qwen3_model(model_name)

    # Apply chat template for Qwen3
    if is_qwen3:
        prompt = apply_qwen3_chat_template(tokenizer, prompt)
        max_tokens = max_tokens or 2048
    else:
        max_tokens = max_tokens or 1024

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response[len(prompt) :].strip()

    # Extract final answer for thinking models
    if is_qwen3:
        return extract_thinking_answer(response)
    return response


def generate_vllm(llm, sampling_params, prompt):
    """Generate with vLLM - supports Qwen3 thinking models"""
    outputs = llm.generate([prompt], sampling_params)
    response = outputs[0].outputs[0].text.strip()

    # Extract final answer for Qwen3 thinking models
    model_name = str(llm.llm_engine.model_config.model)
    if is_qwen3_model(model_name):
        return extract_thinking_answer(response)
    return response


def create_prompt(context, query, domain="default"):
    """Create RAG prompt"""
    if domain == "emails":
        return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    elif domain == "finance":
        return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    elif domain == "multimodal":
        return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    else:
        return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"


def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
    """Simple RAG evaluation with timing"""
    search_times = []
    gen_times = []
    results = []

    for i, query in enumerate(queries):
        # Search
        start = time.time()
        docs = searcher.search(query, top_k=top_k, complexity=complexity)
        search_time = time.time() - start

        # Generate
        context = "\n\n".join([doc.text for doc in docs])
        prompt = create_prompt(context, query, domain)

        start = time.time()
        response = llm_func(prompt)
        gen_time = time.time() - start

        search_times.append(search_time)
        gen_times.append(gen_time)
        results.append(response)

        if i < 3:
            print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")

    return {
        "avg_search_time": sum(search_times) / len(search_times),
        "avg_generation_time": sum(gen_times) / len(gen_times),
        "results": results,
    }
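A note on the interfaces evaluate_rag expects: the searcher only needs a search() method returning objects with a .text attribute, and llm_func is any callable from prompt string to answer string. A minimal sketch with hypothetical stand-ins, run alongside the helpers above (the stubs are purely illustrative, not part of the benchmark code):

from dataclasses import dataclass


@dataclass
class _Doc:
    text: str


class _StubSearcher:
    # Mimics the searcher interface used by evaluate_rag: search(query, top_k, complexity).
    def search(self, query, top_k=3, complexity=64):
        return [_Doc(text=f"passage {j} about {query}") for j in range(top_k)]


stats = evaluate_rag(
    searcher=_StubSearcher(),
    llm_func=lambda prompt: "stub answer",
    queries=["what is leann?", "how is recall@k computed?"],
    domain="finance",
)
print(stats["avg_search_time"], stats["avg_generation_time"])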

def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct"):
    """Load Qwen2.5-VL multimodal model"""
    if not HF_AVAILABLE:
        raise ImportError("transformers not available")

    print(f"Loading Qwen2.5-VL: {model_name}")

    try:
        from transformers import AutoModelForVision2Seq, AutoProcessor

        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForVision2Seq.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
        )

        return processor, model

    except Exception as e:
        print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")

        # Fallback to specific class
        try:
            from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

            processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
            )

            return processor, model

        except Exception as e2:
            raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")


def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
    """Generate with Qwen2.5-VL multimodal model"""
    from PIL import Image

    # Prepare inputs
    if image_path:
        image = Image.open(image_path)
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device)

    # Generate
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
        )

    # Decode response
    generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
    response = processor.decode(generated_ids[0], skip_special_tokens=True)

    return response


def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
    """Create prompt for multimodal RAG"""
    if task_type == "images":
        return f"""Based on the retrieved images and their descriptions, answer the following question.

Retrieved Image Descriptions:
{context}

Question: {query}

Provide a detailed answer based on the visual content described above."""

    return f"Context: {context}\nQuestion: {query}\nAnswer:"


def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
    """Evaluate multimodal RAG with Qwen2.5-VL"""
    search_times = []
    gen_times = []
    results = []

    for i, query_item in enumerate(queries):
        # Handle both string and dict formats for queries
        if isinstance(query_item, dict):
            query = query_item.get("query", "")
            image_path = query_item.get("image_path")  # Optional reference image
        else:
            query = str(query_item)
            image_path = None

        # Search
        start_time = time.time()
        search_results = searcher.search(query, top_k=3, complexity=complexity)
        search_time = time.time() - start_time
        search_times.append(search_time)

        # Prepare context from search results
        context_parts = []
        for result in search_results:
            context_parts.append(f"- {result.text}")
        context = "\n".join(context_parts)

        # Generate with multimodal model
        start_time = time.time()
        if processor and model:
            prompt = create_multimodal_prompt(context, query, context_parts)
            response = generate_qwen_vl(processor, model, prompt, image_path)
        else:
            response = f"Context: {context}"
        gen_time = time.time() - start_time

        gen_times.append(gen_time)
        results.append(response)

        if i < 3:
            print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")

    return {
        "avg_search_time": sum(search_times) / len(search_times),
        "avg_generation_time": sum(gen_times) / len(gen_times),
        "results": results,
    }
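Taken together, the text path of these utilities composes as follows. This is a hedged end-to-end sketch, assuming benchmarks/llm_utils.py is importable, a capable GPU is available, and the context string is purely illustrative:

from llm_utils import create_prompt, generate_hf, load_hf_model

# Load Qwen3-8B through the HF path and answer a single RAG-style prompt.
tokenizer, model = load_hf_model("Qwen/Qwen3-8B")
context = "LEANN stores a pruned graph and recomputes embeddings at query time to save storage."
prompt = create_prompt(context, "What does LEANN trade for its small index size?", domain="default")
print(generate_hf(tokenizer, model, prompt, max_tokens=256))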
@@ -2,20 +2,20 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from contextlib import contextmanager
|
from transformers import AutoModel, BitsAndBytesConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str
|
model_path: str
|
||||||
batch_sizes: List[int]
|
batch_sizes: list[int]
|
||||||
seq_length: int
|
seq_length: int
|
||||||
num_runs: int
|
num_runs: int
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -32,13 +32,11 @@ class GraphContainer:
|
|||||||
def __init__(self, model: nn.Module, seq_length: int):
|
def __init__(self, model: nn.Module, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
self.graphs: dict[int, GraphWrapper] = {}
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||||
if batch_size not in self.graphs:
|
if batch_size not in self.graphs:
|
||||||
self.graphs[batch_size] = GraphWrapper(
|
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
return self.graphs[batch_size]
|
||||||
|
|
||||||
|
|
||||||
@@ -55,13 +53,13 @@ class GraphWrapper:
|
|||||||
self._warmup()
|
self._warmup()
|
||||||
|
|
||||||
# Only use CUDA graphs on NVIDIA GPUs
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||||
# Capture graph
|
# Capture graph
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(self.graph):
|
with torch.cuda.graph(self.graph):
|
||||||
self.static_output = self.model(
|
self.static_output = self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
self.use_cuda_graph = True
|
self.use_cuda_graph = True
|
||||||
else:
|
else:
|
||||||
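The capture logic in the hunk above follows PyTorch's standard CUDA Graph capture-and-replay pattern. A minimal self-contained sketch of that pattern (requires an NVIDIA GPU; the toy linear model and shapes are illustrative, not the benchmark's encoder):

import torch
from torch import nn

model = nn.Linear(256, 256).cuda().half().eval()
static_input = torch.randn(8, 256, device="cuda", dtype=torch.half)

# Warm up before capture, as GraphWrapper does.
with torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.synchronize()

# Capture one forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the static buffer, then replay the captured kernels.
static_input.copy_(torch.randn_like(static_input))
graph.replay()
print(static_output.shape)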
@@ -79,9 +77,7 @@ class GraphWrapper:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000, (batch_size, seq_length),
|
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||||
device=self.device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
def _warmup(self, num_warmup: int = 3):
|
||||||
@@ -89,7 +85,7 @@ class GraphWrapper:
|
|||||||
for _ in range(num_warmup):
|
for _ in range(num_warmup):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -133,8 +129,12 @@ class ModelOptimizer:
|
|||||||
print("- Using FP16 precision")
|
print("- Using FP16 precision")
|
||||||
|
|
||||||
# Check if using SDPA (only on CUDA)
|
# Check if using SDPA (only on CUDA)
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -142,7 +142,8 @@ class ModelOptimizer:
|
|||||||
# Flash Attention (only on CUDA)
|
# Flash Attention (only on CUDA)
|
||||||
if config.use_flash_attention and torch.cuda.is_available():
|
if config.use_flash_attention and torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||||
|
|
||||||
print("- Flash Attention 2 available")
|
print("- Flash Attention 2 available")
|
||||||
if hasattr(model.config, "attention_mode"):
|
if hasattr(model.config, "attention_mode"):
|
||||||
model.config.attention_mode = "flash_attention_2"
|
model.config.attention_mode = "flash_attention_2"
|
||||||
@@ -153,8 +154,9 @@ class ModelOptimizer:
|
|||||||
# Memory efficient attention (only on CUDA)
|
# Memory efficient attention (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -220,7 +222,7 @@ class Benchmark:
|
|||||||
self.graphs = None
|
self.graphs = None
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
print(f"ERROR in benchmark initialization: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
@@ -230,15 +232,17 @@ class Benchmark:
|
|||||||
# Int4 quantization using HuggingFace integration
|
# Int4 quantization using HuggingFace integration
|
||||||
if self.config.use_int4:
|
if self.config.use_int4:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
# Check if using custom 8bit quantization
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
# Load original model (without quantization config)
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# set default to half
|
# set default to half
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
@@ -247,52 +251,58 @@ class Benchmark:
|
|||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义替换函数
|
# Define replacement function
|
||||||
def replace_linear_with_linear8bitlt(model):
|
def replace_linear_with_linear8bitlt(model):
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||||
for name, module in list(model.named_children()):
|
for name, module in list(model.named_children()):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# 获取原始线性层的参数
|
# Get original linear layer parameters
|
||||||
in_features = module.in_features
|
in_features = module.in_features
|
||||||
out_features = module.out_features
|
out_features = module.out_features
|
||||||
bias = module.bias is not None
|
bias = module.bias is not None
|
||||||
|
|
||||||
# 创建8bit线性层
|
# Create 8bit linear layer
|
||||||
# print size
|
# print size
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
new_module = bnb.nn.Linear8bitLt(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
has_fp16_weights=False
|
has_fp16_weights=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 复制权重和偏置
|
# Copy weights and bias
|
||||||
new_module.weight.data = module.weight.data
|
new_module.weight.data = module.weight.data
|
||||||
if bias:
|
if bias:
|
||||||
new_module.bias.data = module.bias.data
|
new_module.bias.data = module.bias.data
|
||||||
|
|
||||||
# 替换模块
|
# Replace module
|
||||||
setattr(model, name, new_module)
|
setattr(model, name, new_module)
|
||||||
else:
|
else:
|
||||||
# 递归处理子模块
|
# Process child modules recursively
|
||||||
replace_linear_with_linear8bitlt(module)
|
replace_linear_with_linear8bitlt(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 替换所有线性层
|
# Replace all linear layers
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
model = replace_linear_with_linear8bitlt(model)
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
# Move model to GPU (quantization happens here)
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
print("- All linear layers replaced with Linear8bitLt")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用原来的Int4量化方法
|
# Use original Int4 quantization method
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
print("- Using bitsandbytes for Int4 quantization")
|
||||||
|
|
||||||
# Create quantization config
|
# Create quantization config
|
||||||
@@ -302,7 +312,7 @@ class Benchmark:
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4"
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
print("- Quantization config:", quantization_config)
|
||||||
@@ -312,7 +322,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto" # Let HF decide on device mapping
|
device_map="auto", # Let HF decide on device mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if model loaded successfully
|
# Check if model loaded successfully
|
||||||
@@ -324,7 +334,7 @@ class Benchmark:
|
|||||||
# Apply optimizations directly here
|
# Apply optimizations directly here
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||||
else:
|
else:
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
# Skip moving to GPU since device_map="auto" already did that
|
||||||
@@ -334,8 +344,12 @@ class Benchmark:
|
|||||||
print(f"- Using {compute_dtype} for compute dtype")
|
print(f"- Using {compute_dtype} for compute dtype")
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
# Check CUDA and SDPA
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -343,8 +357,7 @@ class Benchmark:
|
|||||||
# Try xformers if available (only on CUDA)
|
# Try xformers if available (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -370,7 +383,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto"
|
device_map="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -389,6 +402,7 @@ class Benchmark:
|
|||||||
# Apply standard optimizations
|
# Apply standard optimizations
|
||||||
# set default to half
|
# set default to half
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
model = ModelOptimizer.optimize(model, self.config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -403,25 +417,31 @@ class Benchmark:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR loading model: {str(e)}")
|
print(f"ERROR loading model: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(
|
def _run_inference(
|
||||||
self,
|
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||||
input_ids: torch.Tensor,
|
) -> tuple[float, torch.Tensor]:
|
||||||
graph_wrapper: Optional[GraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
with torch.no_grad(), self.timer.timing():
|
||||||
@@ -432,7 +452,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
return self.timer.elapsed_time(), output
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Reset peak memory stats
|
# Reset peak memory stats
|
||||||
@@ -450,9 +470,7 @@ class Benchmark:
|
|||||||
|
|
||||||
# Get or create graph for this batch size
|
# Get or create graph for this batch size
|
||||||
graph_wrapper = (
|
graph_wrapper = (
|
||||||
self.graphs.get_or_create(batch_size)
|
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||||
if self.graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
# Pre-allocate input tensor
|
||||||
@@ -490,7 +508,7 @@ class Benchmark:
|
|||||||
|
|
||||||
# Log memory usage
|
# Log memory usage
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
# MPS doesn't have max_memory_allocated, use 0
|
# MPS doesn't have max_memory_allocated, use 0
|
||||||
peak_memory_gb = 0.0
|
peak_memory_gb = 0.0
|
||||||
@@ -604,7 +622,15 @@ def main():
|
|||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
|
|
||||||
# Generate filename based on configuration
|
# Generate filename based on configuration
|
||||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
precision_type = (
|
||||||
|
"int4"
|
||||||
|
if config.use_int4
|
||||||
|
else "int8"
|
||||||
|
if config.use_int8
|
||||||
|
else "fp16"
|
||||||
|
if config.use_fp16
|
||||||
|
else "fp32"
|
||||||
|
)
|
||||||
model_name = os.path.basename(config.model_path)
|
model_name = os.path.basename(config.model_path)
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||||
|
|
||||||
@@ -612,17 +638,20 @@ def main():
|
|||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
"config": {
|
||||||
"results": {str(k): v for k, v in results.items()}
|
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||||
|
},
|
||||||
|
"results": {str(k): v for k, v in results.items()},
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
indent=2
|
indent=2,
|
||||||
)
|
)
|
||||||
print(f"Results saved to {output_file}")
|
print(f"Results saved to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
|||||||
results and the golden standard results, making the comparison robust to ID changes.
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher, LeannBuilder
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
if not data_root.exists():
|
if not data_root.exists():
|
||||||
print(f"Data directory '{data_root}' not found.")
|
print(f"Data directory '{data_root}' not found.")
|
||||||
print(
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -56,14 +53,14 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
print(
|
print(
|
||||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
)
|
)
|
||||||
print("uv pip install -e '.[dev]'")
|
print("uv sync --only-group dev")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred during data download: {e}")
|
print(f"An error occurred during data download: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
"""Download embeddings files specifically."""
|
"""Download embeddings files specifically."""
|
||||||
embeddings_dir = data_root / "embeddings"
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
|||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
# --- Helper Function to get Golden Passages ---
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
"""
|
"""
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
passage manager.
|
passage manager.
|
||||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
|||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
golden_texts.add(passage_data["text"])
|
golden_texts.add(passage_data["text"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
|
||||||
)
|
|
||||||
return golden_texts
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> List[str]:
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
queries = []
|
queries = []
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
queries.append(data["query"])
|
queries.append(data["query"])
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Build a LEANN index from pre-computed embeddings.
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
description="Run recall evaluation on a LEANN index."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"index_path",
|
"index_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -202,26 +193,41 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
# Assumes a project structure where the script is in 'examples/'
|
# Assumes a project structure where the script is in 'benchmarks/'
|
||||||
# and data is in 'data/' at the project root.
|
# and evaluation data is in 'benchmarks/data/'.
|
||||||
project_root = Path(__file__).resolve().parent.parent
|
script_dir = Path(__file__).resolve().parent
|
||||||
data_root = project_root / "data"
|
data_root = script_dir / "data"
|
||||||
|
|
||||||
# Download data based on mode
|
# Download data based on mode
|
||||||
if args.mode == "build":
|
if args.mode == "build":
|
||||||
# For building mode, we need embeddings
|
# For building mode, we need embeddings
|
||||||
download_data_if_needed(
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
data_root, download_embeddings=False
|
|
||||||
) # Basic data first
|
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
# Auto-detect dataset type and download embeddings
|
||||||
if args.embeddings_file:
|
if args.embeddings_file:
|
||||||
@@ -262,9 +268,7 @@ def main():
|
|||||||
print(f"Index built successfully: {built_index_path}")
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
# Ask if user wants to run evaluation
|
||||||
eval_response = (
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
|
||||||
)
|
|
||||||
if eval_response != "y":
|
if eval_response != "y":
|
||||||
print("Index building complete. Exiting.")
|
print("Index building complete. Exiting.")
|
||||||
return
|
return
|
||||||
@@ -293,11 +297,9 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not args.index_path:
|
if not args.index_path:
|
||||||
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
print(
|
print(
|
||||||
"No indices found. The data download should have included pre-built indices."
|
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||||
)
|
|
||||||
print(
|
|
||||||
"Please check the data/indices/ directory or provide --index-path manually."
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@@ -310,14 +312,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# Fallback: try to infer from the index directory name
|
# Fallback: try to infer from the index directory name
|
||||||
dataset_type = Path(args.index_path).name
|
dataset_type = Path(args.index_path).name
|
||||||
print(
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
golden_results_file = (
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
@@ -327,7 +325,7 @@ def main():
|
|||||||
searcher = LeannSearcher(args.index_path)
|
searcher = LeannSearcher(args.index_path)
|
||||||
queries = load_queries(queries_file)
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
with open(golden_results_file, "r") as f:
|
with open(golden_results_file) as f:
|
||||||
golden_results_data = json.load(f)
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
@@ -340,10 +338,23 @@ def main():
|
|||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(
|
new_results = searcher.search(
|
||||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
)
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
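As the docstring for this evaluation script notes, recall is computed over passage texts rather than IDs, so the metric survives re-indexing. A small self-contained sketch of that text-based recall@k (the inputs are hypothetical):

def recall_at_k(new_texts: set, golden_texts: set) -> float:
    # Fraction of golden passages whose text appears among the retrieved top-k results.
    if not golden_texts:
        return 0.0
    return len(new_texts & golden_texts) / len(golden_texts)


print(recall_at_k({"passage a", "passage b", "passage c"}, {"passage b", "passage d"}))  # 0.5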
@@ -1,26 +1,27 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
# Add MLX imports
|
# Add MLX imports
|
||||||
try:
|
try:
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx_lm.utils import load
|
from mlx_lm.utils import load
|
||||||
|
|
||||||
MLX_AVAILABLE = True
|
MLX_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
MLX_AVAILABLE = False
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: List[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -33,7 +34,8 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
"""MLX-specific benchmark for embedding models"""
|
"""MLX-specific benchmark for embedding models"""
|
||||||
@@ -55,11 +57,7 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int):
|
def _create_random_batch(self, batch_size: int):
|
||||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
return torch.randint(
|
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
"""Run MLX inference with same input as PyTorch"""
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
@@ -82,12 +80,12 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX inference error: {e}")
|
print(f"MLX inference error: {e}")
|
||||||
return float('inf')
|
return float("inf")
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
"""Run the MLX benchmark across all batch sizes"""
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
@@ -111,10 +109,10 @@ class MLXBenchmark:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Run benchmark
|
# Run benchmark
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
if elapsed_time != float('inf'):
|
if elapsed_time != float("inf"):
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during MLX inference: {e}")
|
print(f"Error during MLX inference: {e}")
|
||||||
@@ -145,16 +143,22 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, config: BenchmarkConfig):
|
def __init__(self, config: BenchmarkConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
self.device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
self.model = self._load_model()
|
self.model = self._load_model()
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
if self.config.use_fp16:
|
if self.config.use_fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -166,23 +170,30 @@ class Benchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -194,7 +205,7 @@ class Benchmark:
|
|||||||
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
@@ -219,7 +230,7 @@ class Benchmark:
|
|||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
else:
|
else:
|
||||||
peak_memory_gb = 0.0
|
peak_memory_gb = 0.0
|
||||||
|
|
||||||
@@ -228,6 +239,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark():
|
def run_benchmark():
|
||||||
"""Main function to run the benchmark with optimized parameters."""
|
"""Main function to run the benchmark with optimized parameters."""
|
||||||
config = BenchmarkConfig()
|
config = BenchmarkConfig()
|
||||||
@@ -242,16 +254,13 @@ def run_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_mlx_benchmark():
|
def run_mlx_benchmark():
|
||||||
"""Run MLX-specific benchmark"""
|
"""Run MLX-specific benchmark"""
|
||||||
@@ -260,13 +269,10 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "MLX not available"
|
"error": "MLX not available",
|
||||||
}
|
}
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
|
||||||
use_mlx=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
benchmark = MLXBenchmark(config)
|
benchmark = MLXBenchmark(config)
|
||||||
@@ -276,7 +282,7 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "No valid results"
|
"error": "No valid results",
|
||||||
}
|
}
|
||||||
|
|
||||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
@@ -285,16 +291,13 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX benchmark failed: {e}")
|
print(f"MLX benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("=== PyTorch Benchmark ===")
|
print("=== PyTorch Benchmark ===")
|
||||||
@@ -308,7 +311,7 @@ if __name__ == "__main__":
|
|||||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||||
print(f"\n=== Comparison ===")
|
print("\n=== Comparison ===")
|
||||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
data/.gitattributes (vendored, 82 lines deleted)
@@ -1,82 +0,0 @@
|
|||||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.model filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
|
||||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - uncompressed
|
|
||||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Audio files - compressed
|
|
||||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - uncompressed
|
|
||||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.png filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Image files - compressed
|
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
|
||||||
# Video files - compressed
|
|
||||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
|
||||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
|
||||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
|
||||||
BIN
data/2501.14312v1 (1).pdf
Normal file
BIN
data/2501.14312v1 (1).pdf
Normal file
Binary file not shown.
7905
data/2506.08276v1.pdf
Normal file
7905
data/2506.08276v1.pdf
Normal file
File diff suppressed because it is too large
Load Diff
14905
data/PrideandPrejudice.txt
Normal file
14905
data/PrideandPrejudice.txt
Normal file
File diff suppressed because it is too large
Load Diff
105
demo.ipynb
105
demo.ipynb
@@ -1,37 +1,116 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Quick Start \n",
|
||||||
|
"\n",
|
||||||
|
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||||
|
"\n",
|
||||||
|
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from leann.api import LeannBuilder, LeannSearcher, LeannChat\n",
|
"# install this if you are using colab\n",
|
||||||
|
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||||
|
"! uv pip install leann --no-deps\n",
|
||||||
|
"# For Colab environment, we need to set some environment variables\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||||
|
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the index"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannBuilder\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 1. Build the index (no embeddings stored!)\n",
|
|
||||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||||
"builder.add_text(\"C# is a powerful programming language\")\n",
|
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||||
"builder.add_text(\"Python is a powerful programming language and it is very popular\")\n",
|
"builder.add_text(\n",
|
||||||
|
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||||
|
")\n",
|
||||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||||
"builder.build_index(\"knowledge.leann\")\n",
|
"builder.build_index(INDEX_PATH)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Search with real-time embeddings"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannSearcher\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 2. Search with real-time embeddings\n",
|
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
|
||||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||||
|
"results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Chat with LEANN using retrieved results"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from leann.api import LeannChat\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# 3. Chat with LEANN using retrieved results\n",
|
|
||||||
"llm_config = {\n",
|
"llm_config = {\n",
|
||||||
" \"type\": \"ollama\",\n",
|
" \"type\": \"hf\",\n",
|
||||||
" \"model\": \"llama3.2:1b\"\n",
|
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||||
"response = chat.ask(\n",
|
"response = chat.ask(\n",
|
||||||
" \"Compare the two retrieved programming languages and say which one is more popular today.\",\n",
|
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||||
" top_k=2,\n",
|
" top_k=2,\n",
|
||||||
")"
|
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||||
|
")\n",
|
||||||
|
"response"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
223
docs/CONTRIBUTING.md
Normal file
223
docs/CONTRIBUTING.md
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# 🤝 Contributing
|
||||||
|
|
||||||
|
We welcome contributions! Leann is built by the community, for the community.
|
||||||
|
|
||||||
|
## Ways to Contribute
|
||||||
|
|
||||||
|
- 🐛 **Bug Reports**: Found an issue? Let us know!
|
||||||
|
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
|
||||||
|
- 🔧 **Code Contributions**: PRs welcome for all skill levels
|
||||||
|
- 📖 **Documentation**: Help make Leann more accessible
|
||||||
|
- 🧪 **Benchmarks**: Share your performance results
|
||||||
|
|
||||||
|
## 🚀 Development Setup
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
1. **Install uv** (fast Python package installer):
|
||||||
|
```bash
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Clone the repository**:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/LEANN-RAG/LEANN-RAG.git
|
||||||
|
cd LEANN-RAG
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Install system dependencies**:
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
brew install llvm libomp boost protobuf zeromq pkgconf
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ubuntu/Debian:**
|
||||||
|
```bash
|
||||||
|
sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
|
||||||
|
libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Build from source**:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync
|
||||||
|
|
||||||
|
# Ubuntu/Debian
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔨 Pre-commit Hooks
|
||||||
|
|
||||||
|
We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.
|
||||||
|
|
||||||
|
### Setup Pre-commit
|
||||||
|
|
||||||
|
1. **Install pre-commit tools**:
|
||||||
|
```bash
|
||||||
|
uv sync lint
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install the git hooks**:
|
||||||
|
```bash
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Run pre-commit manually** (optional):
|
||||||
|
```bash
|
||||||
|
uv run pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pre-commit Checks
|
||||||
|
|
||||||
|
Our pre-commit configuration includes:
|
||||||
|
- **Trailing whitespace removal**
|
||||||
|
- **End-of-file fixing**
|
||||||
|
- **YAML validation**
|
||||||
|
- **Large file prevention**
|
||||||
|
- **Merge conflict detection**
|
||||||
|
- **Debug statement detection**
|
||||||
|
- **Code formatting with ruff**
|
||||||
|
- **Code linting with ruff**
|
||||||
|
|
||||||
|
## 🧪 Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install test tools only (no project runtime)
|
||||||
|
uv sync --group test
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
uv run pytest
|
||||||
|
|
||||||
|
# Run specific test file
|
||||||
|
uv run pytest test/test_filename.py
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
uv run pytest --cov=leann
|
||||||
|
```
|
||||||
|
|
||||||
|
### Writing Tests
|
||||||
|
|
||||||
|
- Place tests in the `test/` directory
|
||||||
|
- Follow the naming convention `test_*.py`
|
||||||
|
- Use descriptive test names that explain what's being tested
|
||||||
|
- Include both positive and negative test cases
|
||||||
|
|
||||||
|
## 📝 Code Style
|
||||||
|
|
||||||
|
We use `ruff` for both linting and formatting to ensure consistent code style.
|
||||||
|
|
||||||
|
### Format Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Format all files
|
||||||
|
ruff format
|
||||||
|
|
||||||
|
# Check formatting without changing files
|
||||||
|
ruff format --check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Lint Your Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run linter with auto-fix
|
||||||
|
ruff check --fix
|
||||||
|
|
||||||
|
# Just check without fixing
|
||||||
|
ruff check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Style Guidelines
|
||||||
|
|
||||||
|
- Follow PEP 8 conventions
|
||||||
|
- Use descriptive variable names
|
||||||
|
- Add type hints where appropriate
|
||||||
|
- Write docstrings for all public functions and classes
|
||||||
|
- Keep functions focused and single-purpose
|
||||||
|
|
||||||
|
## 🚦 CI/CD
|
||||||
|
|
||||||
|
Our CI pipeline runs automatically on all pull requests. It includes:
|
||||||
|
|
||||||
|
1. **Linting and Formatting**: Ensures code follows our style guidelines
|
||||||
|
2. **Multi-platform builds**: Tests on Ubuntu and macOS
|
||||||
|
3. **Python version matrix**: Tests on Python 3.9-3.13
|
||||||
|
4. **Wheel building**: Ensures packages can be built and distributed
|
||||||
|
|
||||||
|
### CI Commands
|
||||||
|
|
||||||
|
The CI uses the same commands as pre-commit to ensure consistency:
|
||||||
|
```bash
|
||||||
|
# Linting
|
||||||
|
ruff check .
|
||||||
|
|
||||||
|
# Format checking
|
||||||
|
ruff format --check .
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure your code passes these checks locally before pushing!
|
||||||
|
|
||||||
|
## 🔄 Pull Request Process
|
||||||
|
|
||||||
|
1. **Fork the repository** and create your branch from `main`:
|
||||||
|
```bash
|
||||||
|
git checkout -b feature/your-feature-name
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Make your changes**:
|
||||||
|
- Write clean, documented code
|
||||||
|
- Add tests for new functionality
|
||||||
|
- Update documentation as needed
|
||||||
|
|
||||||
|
3. **Run pre-commit checks**:
|
||||||
|
```bash
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Test your changes**:
|
||||||
|
```bash
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
|
||||||
|
5. **Commit with descriptive messages**:
|
||||||
|
```bash
|
||||||
|
git commit -m "feat: add new search algorithm"
|
||||||
|
```
|
||||||
|
|
||||||
|
Follow [Conventional Commits](https://www.conventionalcommits.org/):
|
||||||
|
- `feat:` for new features
|
||||||
|
- `fix:` for bug fixes
|
||||||
|
- `docs:` for documentation changes
|
||||||
|
- `test:` for test additions/changes
|
||||||
|
- `refactor:` for code refactoring
|
||||||
|
- `perf:` for performance improvements
|
||||||
|
|
||||||
|
6. **Push and create a pull request**:
|
||||||
|
- Provide a clear description of your changes
|
||||||
|
- Reference any related issues
|
||||||
|
- Include examples or screenshots if applicable
|
||||||
|
|
||||||
|
## 📚 Documentation
|
||||||
|
|
||||||
|
When adding new features or making significant changes:
|
||||||
|
|
||||||
|
1. Update relevant documentation in `/docs`
|
||||||
|
2. Add docstrings to new functions/classes
|
||||||
|
3. Update README.md if needed
|
||||||
|
4. Include usage examples
|
||||||
|
|
||||||
|
## 🤔 Getting Help
|
||||||
|
|
||||||
|
- **Discord**: Join our community for discussions
|
||||||
|
- **Issues**: Check existing issues or create a new one
|
||||||
|
- **Discussions**: For general questions and ideas
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
|
||||||
22
docs/RELEASE.md
Normal file
22
docs/RELEASE.md
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Release Guide
|
||||||
|
|
||||||
|
## Setup (One-time)
|
||||||
|
|
||||||
|
Add `PYPI_API_TOKEN` to GitHub Secrets:
|
||||||
|
1. Get token: https://pypi.org/manage/account/token/
|
||||||
|
2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
|
||||||
|
|
||||||
|
## Release (One-click)
|
||||||
|
|
||||||
|
1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
|
||||||
|
2. Click "Run workflow"
|
||||||
|
3. Enter version: `0.1.2`
|
||||||
|
4. Click green "Run workflow" button
|
||||||
|
|
||||||
|
That's it! The workflow will automatically:
|
||||||
|
- ✅ Update version in all packages
|
||||||
|
- ✅ Build all packages
|
||||||
|
- ✅ Publish to PyPI
|
||||||
|
- ✅ Create GitHub tag and release
|
||||||
|
|
||||||
|
Check progress: https://github.com/yichuan-w/LEANN/actions
|
||||||
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
123
docs/THINKING_BUDGET_FEATURE.md
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
# Thinking Budget Feature Implementation
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This document describes the implementation of the **thinking budget** feature for LEANN, which allows users to control the computational effort for reasoning models like GPT-Oss:20b.
|
||||||
|
|
||||||
|
## Feature Description
|
||||||
|
|
||||||
|
The thinking budget feature provides three levels of computational effort for reasoning models:
|
||||||
|
- **`low`**: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- **`medium`**: Balanced speed and reasoning depth
|
||||||
|
- **`high`**: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
### 1. Command Line Interface
|
||||||
|
|
||||||
|
Added `--thinking-budget` parameter to both CLI and RAG examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# LEANN CLI
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# RAG Examples
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
python apps/document_rag.py --llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. LLM Backend Support
|
||||||
|
|
||||||
|
#### Ollama Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
options = kwargs.copy()
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget:
|
||||||
|
options.pop("thinking_budget", None)
|
||||||
|
if thinking_budget in ["low", "medium", "high"]:
|
||||||
|
options["reasoning"] = {"effort": thinking_budget, "exclude": False}
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses Ollama's `reasoning` parameter with `effort` and `exclude` fields.
|
||||||
|
|
||||||
|
#### OpenAI Backend (`packages/leann-core/src/leann/chat.py`)
|
||||||
|
|
||||||
|
```python
|
||||||
|
def ask(self, prompt: str, **kwargs) -> str:
|
||||||
|
# Handle thinking budget for reasoning models
|
||||||
|
thinking_budget = kwargs.get("thinking_budget")
|
||||||
|
if thinking_budget and thinking_budget in ["low", "medium", "high"]:
|
||||||
|
# Check if this is an o-series model
|
||||||
|
o_series_models = ["o3", "o3-mini", "o4-mini", "o1", "o3-pro", "o3-deep-research"]
|
||||||
|
if any(model in self.model for model in o_series_models):
|
||||||
|
params["reasoning_effort"] = thinking_budget
|
||||||
|
```
|
||||||
|
|
||||||
|
**API Format**: Uses OpenAI's `reasoning_effort` parameter for o-series models.
|
||||||
|
|
||||||
|
### 3. Parameter Propagation
|
||||||
|
|
||||||
|
The thinking budget parameter is properly propagated through the LEANN architecture:
|
||||||
|
|
||||||
|
1. **CLI** (`packages/leann-core/src/leann/cli.py`): Captures `--thinking-budget` argument
|
||||||
|
2. **Base RAG** (`apps/base_rag_example.py`): Adds parameter to argument parser
|
||||||
|
3. **LeannChat** (`packages/leann-core/src/leann/api.py`): Passes `llm_kwargs` to LLM
|
||||||
|
4. **LLM Interface**: Handles the parameter in backend-specific implementations
|
||||||
|
|
||||||
|
## Files Modified
|
||||||
|
|
||||||
|
### Core Implementation
|
||||||
|
- `packages/leann-core/src/leann/chat.py`: Added thinking budget support to OllamaChat and OpenAIChat
|
||||||
|
- `packages/leann-core/src/leann/cli.py`: Added `--thinking-budget` argument
|
||||||
|
- `apps/base_rag_example.py`: Added thinking budget parameter to RAG examples
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
- `README.md`: Added thinking budget parameter to usage examples
|
||||||
|
- `docs/configuration-guide.md`: Added detailed documentation and usage guidelines
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
- `examples/thinking_budget_demo.py`: Comprehensive demo script with usage examples
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
```bash
|
||||||
|
# High reasoning effort for complex questions
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Medium reasoning for balanced performance
|
||||||
|
leann ask my-index --llm openai --model gpt-4o --thinking-budget medium
|
||||||
|
|
||||||
|
# Low reasoning for fast responses
|
||||||
|
leann ask my-index --llm ollama --model gpt-oss:20b --thinking-budget low
|
||||||
|
```
|
||||||
|
|
||||||
|
### RAG Examples
|
||||||
|
```bash
|
||||||
|
# Email RAG with high reasoning
|
||||||
|
python apps/email_rag.py --llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
|
||||||
|
# Document RAG with medium reasoning
|
||||||
|
python apps/document_rag.py --llm openai --llm-model gpt-4o --thinking-budget medium
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Models
|
||||||
|
|
||||||
|
### Ollama Models
|
||||||
|
- **GPT-Oss:20b**: Primary target model with reasoning capabilities
|
||||||
|
- **Other reasoning models**: Any Ollama model that supports the `reasoning` parameter
|
||||||
|
|
||||||
|
### OpenAI Models
|
||||||
|
- **o3, o3-mini, o4-mini, o1**: o-series reasoning models with `reasoning_effort` parameter
|
||||||
|
- **GPT-OSS models**: Models that support reasoning capabilities
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The implementation includes comprehensive testing:
|
||||||
|
- Parameter handling verification
|
||||||
|
- Backend-specific API format validation
|
||||||
|
- CLI argument parsing tests
|
||||||
|
- Integration with existing LEANN architecture
|
||||||
143
docs/ast_chunking_guide.md
Normal file
143
docs/ast_chunking_guide.md
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# AST-Aware Code chunking guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This guide covers best practices for using AST-aware code chunking in LEANN. AST chunking provides better semantic understanding of code structure compared to traditional text-based chunking.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Enable AST chunking for mixed content (code + docs)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --data-dir ./my_project
|
||||||
|
|
||||||
|
# Specialized code repository indexing
|
||||||
|
python -m apps.code_rag --repo-dir ./my_codebase
|
||||||
|
|
||||||
|
# Global CLI with AST support
|
||||||
|
leann build my-code-index --docs ./src --use-ast-chunking
|
||||||
|
```
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install LEANN with AST chunking support
|
||||||
|
uv pip install -e "."
|
||||||
|
```
|
||||||
|
|
||||||
|
#### For normal users (PyPI install)
|
||||||
|
- Use `pip install leann` or `uv pip install leann`.
|
||||||
|
- `astchunk` is pulled automatically from PyPI as a dependency; no extra steps.
|
||||||
|
|
||||||
|
#### For developers (from source, editable)
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yichuan-w/LEANN.git leann
|
||||||
|
cd leann
|
||||||
|
git submodule update --init --recursive
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
- This repo vendors `astchunk` as a git submodule at `packages/astchunk-leann` (our fork).
|
||||||
|
- `[tool.uv.sources]` maps the `astchunk` package to that path in editable mode.
|
||||||
|
- You can edit code under `packages/astchunk-leann` and Python will use your changes immediately (no separate `pip install astchunk` needed).
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### When to Use AST Chunking
|
||||||
|
|
||||||
|
✅ **Recommended for:**
|
||||||
|
- Code repositories with multiple languages
|
||||||
|
- Mixed documentation and code content
|
||||||
|
- Complex codebases with deep function/class hierarchies
|
||||||
|
- When working with Claude Code for code assistance
|
||||||
|
|
||||||
|
❌ **Not recommended for:**
|
||||||
|
- Pure text documents
|
||||||
|
- Very large files (>1MB)
|
||||||
|
- Languages not supported by tree-sitter
|
||||||
|
|
||||||
|
### Optimal Configuration
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended settings for most codebases
|
||||||
|
python -m apps.code_rag \
|
||||||
|
--repo-dir ./src \
|
||||||
|
--ast-chunk-size 768 \
|
||||||
|
--ast-chunk-overlap 96 \
|
||||||
|
--exclude-dirs .git __pycache__ node_modules build dist
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Languages
|
||||||
|
|
||||||
|
| Extension | Language | Status |
|
||||||
|
|-----------|----------|--------|
|
||||||
|
| `.py` | Python | ✅ Full support |
|
||||||
|
| `.java` | Java | ✅ Full support |
|
||||||
|
| `.cs` | C# | ✅ Full support |
|
||||||
|
| `.ts`, `.tsx` | TypeScript | ✅ Full support |
|
||||||
|
| `.js`, `.jsx` | JavaScript | ✅ Via TypeScript parser |
|
||||||
|
|
||||||
|
## Integration Examples
|
||||||
|
|
||||||
|
### Document RAG with Code Support
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Enable code chunking in document RAG
|
||||||
|
python -m apps.document_rag \
|
||||||
|
--enable-code-chunking \
|
||||||
|
--data-dir ./project \
|
||||||
|
--query "How does authentication work in the codebase?"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Code Integration
|
||||||
|
|
||||||
|
When using with Claude Code MCP server, AST chunking provides better context for:
|
||||||
|
- Code completion and suggestions
|
||||||
|
- Bug analysis and debugging
|
||||||
|
- Architecture understanding
|
||||||
|
- Refactoring assistance
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Fallback to Traditional Chunking**
|
||||||
|
- Normal behavior for unsupported languages
|
||||||
|
- Check logs for specific language support
|
||||||
|
|
||||||
|
2. **Performance with Large Files**
|
||||||
|
- Adjust `--max-file-size` parameter
|
||||||
|
- Use `--exclude-dirs` to skip unnecessary directories
|
||||||
|
|
||||||
|
3. **Quality Issues**
|
||||||
|
- Try different `--ast-chunk-size` values (512, 768, 1024)
|
||||||
|
- Adjust overlap for better context preservation
|
||||||
|
|
||||||
|
### Debug Mode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export LEANN_LOG_LEVEL=DEBUG
|
||||||
|
python -m apps.code_rag --repo-dir ./my_code
|
||||||
|
```
|
||||||
|
|
||||||
|
## Migration from Traditional Chunking
|
||||||
|
|
||||||
|
Existing workflows continue to work without changes. To enable AST chunking:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Before
|
||||||
|
python -m apps.document_rag --chunk-size 256
|
||||||
|
|
||||||
|
# After (maintains traditional chunking for non-code files)
|
||||||
|
python -m apps.document_rag --enable-code-chunking --chunk-size 256 --ast-chunk-size 768
|
||||||
|
```
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [astchunk GitHub Repository](https://github.com/yilinjz/astchunk)
|
||||||
|
- [LEANN MCP Integration](../packages/leann-mcp/README.md)
|
||||||
|
- [Research Paper](https://arxiv.org/html/2506.15655v1)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Note**: AST chunking maintains full backward compatibility while enhancing code understanding capabilities.
|
||||||
98
docs/code/embedding_model_compare.py
Normal file
98
docs/code/embedding_model_compare.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Comparison between Sentence Transformers and OpenAI embeddings
|
||||||
|
|
||||||
|
This example shows how different embedding models handle complex queries
|
||||||
|
and demonstrates the differences between local and API-based embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
|
||||||
|
# OpenAI API key should be set as environment variable
|
||||||
|
# export OPENAI_API_KEY="your-api-key-here"
|
||||||
|
|
||||||
|
# Test data
|
||||||
|
conference_text = "[Title]: COLING 2025 Conference\n[URL]: https://coling2025.org/"
|
||||||
|
browser_text = "[Title]: Browser Use Tool\n[URL]: https://github.com/browser-use"
|
||||||
|
|
||||||
|
# Two queries with same intent but different wording
|
||||||
|
query1 = "Tell me my browser history about some conference i often visit"
|
||||||
|
query2 = "browser history about conference I often visit"
|
||||||
|
|
||||||
|
texts = [query1, query2, conference_text, browser_text]
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) # Already normalized
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_embeddings(embeddings, model_name):
|
||||||
|
print(f"\n=== {model_name} Results ===")
|
||||||
|
|
||||||
|
# Results for Query 1
|
||||||
|
sim1_conf = cosine_similarity(embeddings[0], embeddings[2])
|
||||||
|
sim1_browser = cosine_similarity(embeddings[0], embeddings[3])
|
||||||
|
|
||||||
|
print(f"Query 1: '{query1}'")
|
||||||
|
print(f" → Conference similarity: {sim1_conf:.4f} {'✓' if sim1_conf > sim1_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim1_browser:.4f} {'✓' if sim1_browser > sim1_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim1_conf > sim1_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Results for Query 2
|
||||||
|
sim2_conf = cosine_similarity(embeddings[1], embeddings[2])
|
||||||
|
sim2_browser = cosine_similarity(embeddings[1], embeddings[3])
|
||||||
|
|
||||||
|
print(f"\nQuery 2: '{query2}'")
|
||||||
|
print(f" → Conference similarity: {sim2_conf:.4f} {'✓' if sim2_conf > sim2_browser else ''}")
|
||||||
|
print(
|
||||||
|
f" → Browser similarity: {sim2_browser:.4f} {'✓' if sim2_browser > sim2_conf else ''}"
|
||||||
|
)
|
||||||
|
print(f" Winner: {'Conference' if sim2_conf > sim2_browser else 'Browser'}")
|
||||||
|
|
||||||
|
# Show the impact
|
||||||
|
print("\n=== Impact Analysis ===")
|
||||||
|
print(f"Conference similarity change: {sim2_conf - sim1_conf:+.4f}")
|
||||||
|
print(f"Browser similarity change: {sim2_browser - sim1_browser:+.4f}")
|
||||||
|
|
||||||
|
if sim1_conf > sim1_browser and sim2_browser > sim2_conf:
|
||||||
|
print("❌ FLIP: Adding 'browser history' flips winner from Conference to Browser!")
|
||||||
|
elif sim1_conf > sim1_browser and sim2_conf > sim2_browser:
|
||||||
|
print("✅ STABLE: Conference remains winner in both queries")
|
||||||
|
elif sim1_browser > sim1_conf and sim2_browser > sim2_conf:
|
||||||
|
print("✅ STABLE: Browser remains winner in both queries")
|
||||||
|
else:
|
||||||
|
print("🔄 MIXED: Results vary between queries")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query1_conf": sim1_conf,
|
||||||
|
"query1_browser": sim1_browser,
|
||||||
|
"query2_conf": sim2_conf,
|
||||||
|
"query2_browser": sim2_browser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test Sentence Transformers
|
||||||
|
print("Testing Sentence Transformers (facebook/contriever)...")
|
||||||
|
try:
|
||||||
|
st_embeddings = compute_embeddings(texts, "facebook/contriever", mode="sentence-transformers")
|
||||||
|
st_results = analyze_embeddings(st_embeddings, "Sentence Transformers (facebook/contriever)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Sentence Transformers failed: {e}")
|
||||||
|
st_results = None
|
||||||
|
|
||||||
|
# Test OpenAI
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("Testing OpenAI (text-embedding-3-small)...")
|
||||||
|
try:
|
||||||
|
openai_embeddings = compute_embeddings(texts, "text-embedding-3-small", mode="openai")
|
||||||
|
openai_results = analyze_embeddings(openai_embeddings, "OpenAI (text-embedding-3-small)")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ OpenAI failed: {e}")
|
||||||
|
openai_results = None
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
if st_results and openai_results:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("=== COMPARISON SUMMARY ===")
|
||||||
459
docs/configuration-guide.md
Normal file
459
docs/configuration-guide.md
Normal file
@@ -0,0 +1,459 @@
|
|||||||
|
# LEANN Configuration Guide
|
||||||
|
|
||||||
|
This guide helps you optimize LEANN for different use cases and understand the trade-offs between various configuration options.
|
||||||
|
|
||||||
|
## Getting Started: Simple is Better
|
||||||
|
|
||||||
|
When first trying LEANN, start with a small dataset to quickly validate your approach:
|
||||||
|
|
||||||
|
**For document RAG**: The default `data/` directory works perfectly - includes 2 AI research papers, Pride and Prejudice literature, and a technical report
|
||||||
|
```bash
|
||||||
|
python -m apps.document_rag --query "What techniques does LEANN use?"
|
||||||
|
```
|
||||||
|
|
||||||
|
**For other data sources**: Limit the dataset size for quick testing
|
||||||
|
```bash
|
||||||
|
# WeChat: Test with recent messages only
|
||||||
|
python -m apps.wechat_rag --max-items 100 --query "What did we discuss about the project timeline?"
|
||||||
|
|
||||||
|
# Browser history: Last few days
|
||||||
|
python -m apps.browser_rag --max-items 500 --query "Find documentation about vector databases"
|
||||||
|
|
||||||
|
# Email: Recent inbox
|
||||||
|
python -m apps.email_rag --max-items 200 --query "Who sent updates about the deployment status?"
|
||||||
|
```
|
||||||
|
|
||||||
|
Once validated, scale up gradually:
|
||||||
|
- 100 documents → 1,000 → 10,000 → full dataset (`--max-items -1`)
|
||||||
|
- This helps identify issues early before committing to long processing times
|
||||||
|
|
||||||
|
## Embedding Model Selection: Understanding the Trade-offs
|
||||||
|
|
||||||
|
Based on our experience developing LEANN, embedding models fall into three categories:
|
||||||
|
|
||||||
|
### Small Models (< 100M parameters)
|
||||||
|
**Example**: `sentence-transformers/all-MiniLM-L6-v2` (22M params)
|
||||||
|
- **Pros**: Lightweight, fast for both indexing and inference
|
||||||
|
- **Cons**: Lower semantic understanding, may miss nuanced relationships
|
||||||
|
- **Use when**: Speed is critical, handling simple queries, interactive mode, or just experimenting with LEANN. If time is not a constraint, consider using a larger/better embedding model
|
||||||
|
|
||||||
|
### Medium Models (100M-500M parameters)
|
||||||
|
**Example**: `facebook/contriever` (110M params), `BAAI/bge-base-en-v1.5` (110M params)
|
||||||
|
- **Pros**: Balanced performance, good multilingual support, reasonable speed
|
||||||
|
- **Cons**: Requires more compute than small models
|
||||||
|
- **Use when**: Need quality results without extreme compute requirements, general-purpose RAG applications
|
||||||
|
|
||||||
|
### Large Models (500M+ parameters)
|
||||||
|
**Example**: `Qwen/Qwen3-Embedding-0.6B` (600M params), `intfloat/multilingual-e5-large` (560M params)
|
||||||
|
- **Pros**: Best semantic understanding, captures complex relationships, excellent multilingual support. **Qwen3-Embedding-0.6B achieves nearly OpenAI API performance!**
|
||||||
|
- **Cons**: Slower inference, longer index build times
|
||||||
|
- **Use when**: Quality is paramount and you have sufficient compute resources. **Highly recommended** for production use
|
||||||
|
|
||||||
|
### Quick Start: Cloud and Local Embedding Options
|
||||||
|
|
||||||
|
**OpenAI Embeddings (Fastest Setup)**
|
||||||
|
For immediate testing without local model downloads(also if you [do not have GPU](https://github.com/yichuan-w/LEANN/issues/43) and do not care that much about your document leak, you should use this, we compute the embedding and recompute using openai API):
|
||||||
|
```bash
|
||||||
|
# Set OpenAI embeddings (requires OPENAI_API_KEY)
|
||||||
|
--embedding-mode openai --embedding-model text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ollama Embeddings (Privacy-Focused)**
|
||||||
|
For local embeddings with complete privacy:
|
||||||
|
```bash
|
||||||
|
# First, pull an embedding model
|
||||||
|
ollama pull nomic-embed-text
|
||||||
|
|
||||||
|
# Use Ollama embeddings
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>Cloud vs Local Trade-offs</strong></summary>
|
||||||
|
|
||||||
|
**OpenAI Embeddings** (`text-embedding-3-small/large`)
|
||||||
|
- **Pros**: No local compute needed, consistently fast, high quality
|
||||||
|
- **Cons**: Requires API key, costs money, data leaves your system, [known limitations with certain languages](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
|
||||||
|
- **When to use**: Prototyping, non-sensitive data, need immediate results
|
||||||
|
|
||||||
|
**Local Embeddings**
|
||||||
|
- **Pros**: Complete privacy, no ongoing costs, full control, can sometimes outperform OpenAI embeddings
|
||||||
|
- **Cons**: Slower than cloud APIs, requires local compute resources
|
||||||
|
- **When to use**: Production systems, sensitive data, cost-sensitive applications
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Local & Remote Inference Endpoints
|
||||||
|
|
||||||
|
> Applies to both LLMs (`leann ask`) and embeddings (`leann build`).
|
||||||
|
|
||||||
|
LEANN now treats Ollama, LM Studio, and other OpenAI-compatible runtimes as first-class providers. You can point LEANN at any compatible endpoint – either on the same machine or across the network – with a couple of flags or environment variables.
|
||||||
|
|
||||||
|
### One-Time Environment Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Works for OpenAI-compatible runtimes such as LM Studio, vLLM, SGLang, llamafile, etc.
|
||||||
|
export OPENAI_API_KEY="your-key" # or leave unset for local servers that do not check keys
|
||||||
|
export OPENAI_BASE_URL="http://localhost:1234/v1"
|
||||||
|
|
||||||
|
# Ollama-compatible runtimes (Ollama, Ollama on another host, llamacpp-server, etc.)
|
||||||
|
export LEANN_OLLAMA_HOST="http://localhost:11434" # falls back to OLLAMA_HOST or LOCAL_LLM_ENDPOINT
|
||||||
|
```
|
||||||
|
|
||||||
|
LEANN also recognises `LEANN_LOCAL_LLM_HOST` (highest priority), `LEANN_OPENAI_BASE_URL`, and `LOCAL_OPENAI_BASE_URL`, so existing scripts continue to work.
|
||||||
|
|
||||||
|
### Passing Hosts Per Command
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build an index with a remote embedding server
|
||||||
|
leann build my-notes \
|
||||||
|
--docs ./notes \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-qwen3-embedding-0.6b \
|
||||||
|
--embedding-api-base http://192.168.1.50:1234/v1 \
|
||||||
|
--embedding-api-key local-dev-key
|
||||||
|
|
||||||
|
# Query using a local LM Studio instance via OpenAI-compatible API
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm openai \
|
||||||
|
--llm-model qwen3-8b \
|
||||||
|
--api-base http://localhost:1234/v1 \
|
||||||
|
--api-key local-dev-key
|
||||||
|
|
||||||
|
# Query an Ollama instance running on another box
|
||||||
|
leann ask my-notes \
|
||||||
|
--llm ollama \
|
||||||
|
--llm-model qwen3:14b \
|
||||||
|
--host http://192.168.1.101:11434
|
||||||
|
```
|
||||||
|
|
||||||
|
⚠️ **Make sure the endpoint is reachable**: when your inference server runs on a home/workstation and the index/search job runs in the cloud, the server must be able to reach the host you configured. Typical options include:
|
||||||
|
|
||||||
|
- Expose a public IP (and open the relevant port) on the machine that hosts LM Studio/Ollama.
|
||||||
|
- Configure router or cloud provider port forwarding.
|
||||||
|
- Tunnel traffic through tools like `tailscale`, `cloudflared`, or `ssh -R`.
|
||||||
|
|
||||||
|
When you set these options while building an index, LEANN stores them in `meta.json`. Any subsequent `leann ask` or searcher process automatically reuses the same provider settings – even when we spawn background embedding servers. This makes the “server without GPU talking to my local workstation” workflow from [issue #80](https://github.com/yichuan-w/LEANN/issues/80#issuecomment-2287230548) work out-of-the-box.
|
||||||
|
|
||||||
|
**Tip:** If your runtime does not require an API key (many local stacks don’t), leave `--api-key` unset. LEANN will skip injecting credentials.
|
||||||
|
|
||||||
|
### Python API Usage
|
||||||
|
|
||||||
|
You can pass the same configuration from Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_mode="openai",
|
||||||
|
embedding_model="text-embedding-qwen3-embedding-0.6b",
|
||||||
|
embedding_options={
|
||||||
|
"base_url": "http://192.168.1.50:1234/v1",
|
||||||
|
"api_key": "local-dev-key",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.build_index("./indexes/my-notes", chunks)
|
||||||
|
```
|
||||||
|
|
||||||
|
`embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you).
|
||||||
|
|
||||||
|
## Index Selection: Matching Your Scale
|
||||||
|
|
||||||
|
### HNSW (Hierarchical Navigable Small World)
|
||||||
|
**Best for**: Small to medium datasets (< 10M vectors) - **Default and recommended for extreme low storage**
|
||||||
|
- Full recomputation required
|
||||||
|
- High memory usage during build phase
|
||||||
|
- Excellent recall (95%+)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Optimal for most use cases
|
||||||
|
--backend-name hnsw --graph-degree 32 --build-complexity 64
|
||||||
|
```
|
||||||
|
|
||||||
|
### DiskANN
|
||||||
|
**Best for**: Large datasets, especially when you want `recompute=True`.
|
||||||
|
|
||||||
|
**Key advantages:**
|
||||||
|
- **Faster search** on large datasets (3x+ speedup vs HNSW in many cases)
|
||||||
|
- **Smart storage**: `recompute=True` enables automatic graph partitioning for smaller indexes
|
||||||
|
- **Better scaling**: Designed for 100k+ documents
|
||||||
|
|
||||||
|
**Recompute behavior:**
|
||||||
|
- `recompute=True` (recommended): Pure PQ traversal + final reranking - faster and enables partitioning
|
||||||
|
- `recompute=False`: PQ + partial real distances during traversal - slower but higher accuracy
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Recommended for most use cases
|
||||||
|
--backend-name diskann --graph-degree 32 --build-complexity 64
|
||||||
|
```
|
||||||
|
|
||||||
|
**Performance Benchmark**: Run `uv run benchmarks/diskann_vs_hnsw_speed_comparison.py` to compare DiskANN and HNSW on your system.
|
||||||
|
|
||||||
|
## LLM Selection: Engine and Model Comparison
|
||||||
|
|
||||||
|
### LLM Engines
|
||||||
|
|
||||||
|
**OpenAI** (`--llm openai`)
|
||||||
|
- **Pros**: Best quality, consistent performance, no local resources needed
|
||||||
|
- **Cons**: Costs money ($0.15-2.5 per million tokens), requires internet, data privacy concerns
|
||||||
|
- **Models**: `gpt-4o-mini` (fast, cheap), `gpt-4o` (best quality), `o3` (reasoning), `o3-mini` (reasoning, cheaper)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for o-series reasoning models (o3, o3-mini, o4-mini)
|
||||||
|
- **Note**: Our current default, but we recommend switching to Ollama for most use cases
|
||||||
|
|
||||||
|
**Ollama** (`--llm ollama`)
|
||||||
|
- **Pros**: Fully local, free, privacy-preserving, good model variety
|
||||||
|
- **Cons**: Requires local GPU/CPU resources, slower than cloud APIs, need to install extra [ollama app](https://github.com/ollama/ollama?tab=readme-ov-file#ollama) and pre-download models by `ollama pull`
|
||||||
|
- **Models**: `qwen3:0.6b` (ultra-fast), `qwen3:1.7b` (balanced), `qwen3:4b` (good quality), `qwen3:7b` (high quality), `deepseek-r1:1.5b` (reasoning)
|
||||||
|
- **Thinking Budget**: Use `--thinking-budget low/medium/high` for reasoning models like GPT-Oss:20b
|
||||||
|
|
||||||
|
**HuggingFace** (`--llm hf`)
|
||||||
|
- **Pros**: Free tier available, huge model selection, direct model loading (vs Ollama's server-based approach)
|
||||||
|
- **Cons**: More complex initial setup
|
||||||
|
- **Models**: `Qwen/Qwen3-1.7B-FP8`
|
||||||
|
|
||||||
|
## Parameter Tuning Guide
|
||||||
|
|
||||||
|
### Search Complexity Parameters
|
||||||
|
|
||||||
|
**`--build-complexity`** (index building)
|
||||||
|
- Controls thoroughness during index construction
|
||||||
|
- Higher = better recall but slower build
|
||||||
|
- Recommendations:
|
||||||
|
- 32: Quick prototyping
|
||||||
|
- 64: Balanced (default)
|
||||||
|
- 128: Production systems
|
||||||
|
- 256: Maximum quality
|
||||||
|
|
||||||
|
**`--search-complexity`** (query time)
|
||||||
|
- Controls search thoroughness
|
||||||
|
- Higher = better results but slower
|
||||||
|
- Recommendations:
|
||||||
|
- 16: Fast/Interactive search
|
||||||
|
- 32: High quality with diversity
|
||||||
|
- 64+: Maximum accuracy
|
||||||
|
|
||||||
|
### Top-K Selection
|
||||||
|
|
||||||
|
**`--top-k`** (number of retrieved chunks)
|
||||||
|
- More chunks = better context but slower LLM processing
|
||||||
|
- Should be always smaller than `--search-complexity`
|
||||||
|
- Guidelines:
|
||||||
|
- 10-20: General questions (default: 20)
|
||||||
|
- 30+: Complex multi-hop reasoning requiring comprehensive context
|
||||||
|
|
||||||
|
**Trade-off formula**:
|
||||||
|
- Retrieval time ∝ log(n) × search_complexity
|
||||||
|
- LLM processing time ∝ top_k × chunk_size
|
||||||
|
- Total context = top_k × chunk_size tokens
|
||||||
|
|
||||||
|
### Thinking Budget for Reasoning Models
|
||||||
|
|
||||||
|
**`--thinking-budget`** (reasoning effort level)
|
||||||
|
- Controls the computational effort for reasoning models
|
||||||
|
- Options: `low`, `medium`, `high`
|
||||||
|
- Guidelines:
|
||||||
|
- `low`: Fast responses, basic reasoning (default for simple queries)
|
||||||
|
- `medium`: Balanced speed and reasoning depth
|
||||||
|
- `high`: Maximum reasoning effort, best for complex analytical questions
|
||||||
|
- **Supported Models**:
|
||||||
|
- **Ollama**: `gpt-oss:20b`, `gpt-oss:120b`
|
||||||
|
- **OpenAI**: `o3`, `o3-mini`, `o4-mini`, `o1` (o-series reasoning models)
|
||||||
|
- **Note**: Models without reasoning support will show a warning and proceed without reasoning parameters
|
||||||
|
- **Example**: `--thinking-budget high` for complex analytical questions
|
||||||
|
|
||||||
|
**📖 For detailed usage examples and implementation details, check out [Thinking Budget Documentation](THINKING_BUDGET_FEATURE.md)**
|
||||||
|
|
||||||
|
**💡 Quick Examples:**
|
||||||
|
```bash
|
||||||
|
# OpenAI o-series reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm openai --llm-model o3 --thinking-budget medium
|
||||||
|
|
||||||
|
# Ollama reasoning model
|
||||||
|
python apps/document_rag.py --query "What are the main techniques LEANN explores?" \
|
||||||
|
--index-dir hnswbuild --backend hnsw \
|
||||||
|
--llm ollama --llm-model gpt-oss:20b --thinking-budget high
|
||||||
|
```
|
||||||
|
|
||||||
|
### Graph Degree (HNSW/DiskANN)
|
||||||
|
|
||||||
|
**`--graph-degree`**
|
||||||
|
- Number of connections per node in the graph
|
||||||
|
- Higher = better recall but more memory
|
||||||
|
- HNSW: 16-32 (default: 32)
|
||||||
|
- DiskANN: 32-128 (default: 64)
|
||||||
|
|
||||||
|
|
||||||
|
## Performance Optimization Checklist
|
||||||
|
|
||||||
|
### If Embedding is Too Slow
|
||||||
|
|
||||||
|
1. **Switch to smaller model**:
|
||||||
|
```bash
|
||||||
|
# From large model
|
||||||
|
--embedding-model Qwen/Qwen3-Embedding-0.6B
|
||||||
|
# To small model
|
||||||
|
--embedding-model sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Limit dataset size for testing**:
|
||||||
|
```bash
|
||||||
|
--max-items 1000 # Process first 1k items only
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Use MLX on Apple Silicon** (optional optimization):
|
||||||
|
```bash
|
||||||
|
--embedding-mode mlx --embedding-model mlx-community/Qwen3-Embedding-0.6B-8bit
|
||||||
|
```
|
||||||
|
MLX might not be the best choice, as we tested and found that it only offers 1.3x acceleration compared to HF, so maybe using ollama is a better choice for embedding generation
|
||||||
|
|
||||||
|
4. **Use Ollama**
|
||||||
|
```bash
|
||||||
|
--embedding-mode ollama --embedding-model nomic-embed-text
|
||||||
|
```
|
||||||
|
To discover additional embedding models in ollama, check out https://ollama.com/search?c=embedding or read more about embedding models at https://ollama.com/blog/embedding-models, please do check the model size that works best for you
|
||||||
|
### If Search Quality is Poor
|
||||||
|
|
||||||
|
1. **Increase retrieval count**:
|
||||||
|
```bash
|
||||||
|
--top-k 30 # Retrieve more candidates
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Upgrade embedding model**:
|
||||||
|
```bash
|
||||||
|
# For English
|
||||||
|
--embedding-model BAAI/bge-base-en-v1.5
|
||||||
|
# For multilingual
|
||||||
|
--embedding-model intfloat/multilingual-e5-large
|
||||||
|
```
|
||||||
|
|
||||||
|
## Understanding the Trade-offs
|
||||||
|
|
||||||
|
Every configuration choice involves trade-offs:
|
||||||
|
|
||||||
|
| Factor | Small/Fast | Large/Quality |
|
||||||
|
|--------|------------|---------------|
|
||||||
|
| Embedding Model | `all-MiniLM-L6-v2` | `Qwen/Qwen3-Embedding-0.6B` |
|
||||||
|
| Chunk Size | 512 tokens | 128 tokens |
|
||||||
|
| Index Type | HNSW | DiskANN |
|
||||||
|
| LLM | `qwen3:1.7b` | `gpt-4o` |
|
||||||
|
|
||||||
|
The key is finding the right balance for your specific use case. Start small and simple, measure performance, then scale up only where needed.
|
||||||
|
|
||||||
|
## Low-resource setups
|
||||||
|
|
||||||
|
If you don’t have a local GPU or builds/searches are too slow, use one or more of the options below.
|
||||||
|
|
||||||
|
### 1) Use OpenAI embeddings (no local compute)
|
||||||
|
|
||||||
|
Fastest path with zero local GPU requirements. Set your API key and use OpenAI embeddings during build and search:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY=sk-...
|
||||||
|
|
||||||
|
# Build with OpenAI embeddings
|
||||||
|
leann build my-index \
|
||||||
|
--embedding-mode openai \
|
||||||
|
--embedding-model text-embedding-3-small
|
||||||
|
|
||||||
|
# Search with OpenAI embeddings (recompute at query time)
|
||||||
|
leann search my-index "your query" \
|
||||||
|
--recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2) Run remote builds with SkyPilot (cloud GPU)
|
||||||
|
|
||||||
|
Offload embedding generation and index building to a GPU VM using [SkyPilot](https://skypilot.readthedocs.io/en/latest/). A template is provided at `sky/leann-build.yaml`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# One-time: install and configure SkyPilot
|
||||||
|
pip install skypilot
|
||||||
|
|
||||||
|
# Launch with defaults (L4:1) and mount ./data to ~/leann-data; the build runs automatically
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml
|
||||||
|
|
||||||
|
# Override parameters via -e key=value (optional)
|
||||||
|
sky launch -c leann-gpu sky/leann-build.yaml \
|
||||||
|
-e index_name=my-index \
|
||||||
|
-e backend=hnsw \
|
||||||
|
-e embedding_mode=sentence-transformers \
|
||||||
|
-e embedding_model=Qwen/Qwen3-Embedding-0.6B
|
||||||
|
|
||||||
|
# Copy the built index back to your local .leann (use rsync)
|
||||||
|
rsync -Pavz leann-gpu:~/.leann/indexes/my-index ./.leann/indexes/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3) Disable recomputation to trade storage for speed
|
||||||
|
|
||||||
|
If you need lower latency and have more storage/memory, disable recomputation. This stores full embeddings and avoids recomputing at search time.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build without recomputation (HNSW requires non-compact in this mode)
|
||||||
|
leann build my-index --no-recompute --no-compact
|
||||||
|
|
||||||
|
# Search without recomputation
|
||||||
|
leann search my-index "your query" --no-recompute
|
||||||
|
```
|
||||||
|
|
||||||
|
When to use:
|
||||||
|
- Extreme low latency requirements (high QPS, interactive assistants)
|
||||||
|
- Read-heavy workloads where storage is cheaper than latency
|
||||||
|
- No always-available GPU
|
||||||
|
|
||||||
|
Constraints:
|
||||||
|
- HNSW: when `--no-recompute` is set, LEANN automatically disables compact mode during build
|
||||||
|
- DiskANN: supported; `--no-recompute` skips selective recompute during search
|
||||||
|
|
||||||
|
Storage impact:
|
||||||
|
- Storing N embeddings of dimension D with float32 requires approximately N × D × 4 bytes
|
||||||
|
- Example: 1,000,000 chunks × 768 dims × 4 bytes ≈ 2.86 GB (plus graph/metadata)
|
||||||
|
|
||||||
|
Converting an existing index (rebuild required):

```bash
# Rebuild in-place (ensure you still have original docs or can regenerate chunks)
leann build my-index --force --no-recompute --no-compact
```

Python API usage:

```python
from leann import LeannSearcher

searcher = LeannSearcher("/path/to/my-index.leann")
results = searcher.search("your query", top_k=10, recompute_embeddings=False)
```

Trade-offs:

- Lower latency and fewer network hops at query time
- Significantly higher storage (10–100× vs selective recomputation)
- Slightly larger memory footprint during build and search

Quick benchmark results (`benchmarks/benchmark_no_recompute.py` with 5k texts, complexity=32):

- HNSW

```text
recompute=True:  search_time=0.818s, size=1.1MB
recompute=False: search_time=0.012s, size=16.6MB
```

- DiskANN

```text
recompute=True:  search_time=0.041s, size=5.9MB
recompute=False: search_time=0.013s, size=24.6MB
```

Conclusion:

- **HNSW**: `no-recompute` is significantly faster (no embedding recomputation) but requires much more storage (it stores all embeddings)
- **DiskANN**: `no-recompute` uses PQ + partial real distances during traversal (slower but higher accuracy), while `recompute=True` uses pure PQ traversal + final reranking (faster traversal, and enables build-time partitioning for smaller storage)

## Further Reading

- [Lessons Learned Developing LEANN](https://yichuan-w.github.io/blog/lessons_learned_in_dev_leann/)
- [LEANN Technical Paper](https://arxiv.org/abs/2506.08276)
- [DiskANN Original Paper](https://papers.nips.cc/paper/2019/file/09853c7fb1d3f8ee67a61b6bf4a7f8e6-Paper.pdf)
- [SSD-based Graph Partitioning](https://github.com/SonglinLife/SSD_BASED_PLAN)

10 docs/faq.md Normal file
@@ -0,0 +1,10 @@

# FAQ

## 1. My build time seems long

You can speed up the process by using a lightweight embedding model. Add this to your arguments:

```bash
--embedding-model sentence-transformers/all-MiniLM-L6-v2
```

**Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters)

23 docs/features.md Normal file
@@ -0,0 +1,23 @@

# ✨ Detailed Features

## 🔥 Core Features

- **🔄 Real-time Embeddings** - Eliminates heavy embedding storage by computing embeddings on the fly with optimized ZMQ servers, an overlapping and batching search paradigm, and a highly optimized embedding engine
- **🧠 AST-Aware Code Chunking** - Intelligent code chunking that preserves semantic boundaries (functions, classes, methods) for Python, Java, C#, and TypeScript files
- **📈 Scalable Architecture** - Handles millions of documents on consumer hardware; the larger your dataset, the more LEANN can save
- **🎯 Graph Pruning** - Advanced techniques that cut the storage overhead of vector search down to a small footprint
- **🏗️ Pluggable Backends** - HNSW/FAISS (default), with optional DiskANN for large-scale deployments

## 🛠️ Technical Highlights

- **🔄 Recompute Mode** - Highest-accuracy search while eliminating vector storage overhead
- **⚡ Zero-copy Operations** - Minimize IPC overhead by transferring distances instead of embeddings
- **🚀 High-throughput Embedding Pipeline** - Optimized batched processing for maximum efficiency
- **🎯 Two-level Search** - Novel coarse-to-fine search overlap for accelerated query processing (optional)
- **💾 Memory-mapped Indices** - Fast startup with raw text mapping to reduce memory overhead
- **🚀 MLX Support** - Ultra-fast recompute/build with quantized embedding models, accelerating building and search ([minimal example](../examples/mlx_demo.py))

## 🎨 Developer Experience

- **Simple Python API** - Get started in minutes
- **Extensible backend system** - Easy to add new algorithms
- **Comprehensive examples** - From basic usage to production deployment

149 docs/grep_search.md Normal file
@@ -0,0 +1,149 @@

# LEANN Grep Search Usage Guide

## Overview

LEANN's grep search functionality provides exact text matching for finding specific code patterns, error messages, function names, or exact phrases in your indexed documents.

## Basic Usage

### Simple Grep Search

```python
from leann.api import LeannSearcher

searcher = LeannSearcher("your_index_path")

# Exact text search
results = searcher.search("def authenticate_user", use_grep=True, top_k=5)

for result in results:
    print(f"Score: {result.score}")
    print(f"Text: {result.text[:100]}...")
    print("-" * 40)
```

### Comparison: Semantic vs Grep Search

```python
# Semantic search - finds conceptually similar content
semantic_results = searcher.search("machine learning algorithms", top_k=3)

# Grep search - finds exact text matches
grep_results = searcher.search("def train_model", use_grep=True, top_k=3)
```

## When to Use Grep Search

### Use Cases

- **Code Search**: Finding specific function definitions, class names, or variable references
- **Error Debugging**: Locating exact error messages or stack traces
- **Documentation**: Finding specific API endpoints or exact terminology

### Examples

```python
# Find function definitions
functions = searcher.search("def __init__", use_grep=True)

# Find import statements
imports = searcher.search("from sklearn import", use_grep=True)

# Find specific error types
errors = searcher.search("FileNotFoundError", use_grep=True)

# Find TODO comments
todos = searcher.search("TODO:", use_grep=True)

# Find configuration entries
configs = searcher.search("server_port=", use_grep=True)
```

## Technical Details

### How It Works

1. **File Location**: Grep search operates on the raw text stored in `.jsonl` files
2. **Command Execution**: Uses the system `grep` command with case-insensitive search
3. **Result Processing**: Parses JSON lines and extracts text and metadata
4. **Scoring**: Simple frequency-based scoring based on query term occurrences

### Search Process

```
Query: "def train_model"
        ↓
grep -i -n "def train_model" documents.leann.passages.jsonl
        ↓
Parse matching JSON lines
        ↓
Calculate scores based on term frequency
        ↓
Return top_k results
```

### Scoring Algorithm

```python
# Term frequency in document
score = text.lower().count(query.lower())
```

Results are ranked by score (highest first), with higher scores indicating more occurrences of the search term.
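For reference, here is a minimal, self-contained sketch of the pipeline described above. It is illustrative rather than LEANN's internal implementation; the passages filename, helper name, and record fields are assumptions to adapt to your index layout:

```python
import json
import subprocess

def grep_passages(query: str, passages_path: str, top_k: int = 10) -> list[dict]:
    """Sketch of the grep pipeline: exact case-insensitive match, then frequency ranking.

    `passages_path` points at the index's .jsonl passages file (name varies by index).
    """
    # Step 1-2: case-insensitive exact-text match over the raw JSONL passages
    proc = subprocess.run(
        ["grep", "-i", "-n", query, passages_path],
        capture_output=True, text=True,
    )
    scored = []
    for line in proc.stdout.splitlines():
        # Step 3: strip the "<line number>:" prefix added by -n, then parse the JSON record
        _, _, payload = line.partition(":")
        try:
            record = json.loads(payload)
        except json.JSONDecodeError:
            continue
        text = record.get("text", "")
        # Step 4: term-frequency score, matching the scoring rule shown above
        score = text.lower().count(query.lower())
        scored.append({"text": text, "score": score, "metadata": record.get("metadata", {})})
    return sorted(scored, key=lambda r: r["score"], reverse=True)[:top_k]
```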

## Error Handling

### Common Issues

#### Grep Command Not Found

```
RuntimeError: grep command not found. Please install grep or use semantic search.
```

**Solution**: Install grep on your system:

- **Ubuntu/Debian**: `sudo apt-get install grep`
- **macOS**: grep is pre-installed
- **Windows**: Use WSL or install grep via Git Bash/MSYS2

#### No Results Found

```python
# Check if your query exists in the raw data
results = searcher.search("your_query", use_grep=True)
if not results:
    print("No exact matches found. Try:")
    print("1. Check spelling and case")
    print("2. Use partial terms")
    print("3. Switch to semantic search")
```

## Complete Example

```python
#!/usr/bin/env python3
"""
Grep Search Example
Demonstrates grep search for exact text matching.
"""

from leann.api import LeannSearcher

def demonstrate_grep_search():
    # Initialize searcher
    searcher = LeannSearcher("my_index")

    print("=== Function Search ===")
    functions = searcher.search("def __init__", use_grep=True, top_k=5)
    for i, result in enumerate(functions, 1):
        print(f"{i}. Score: {result.score}")
        print(f"   Preview: {result.text[:60]}...")
        print()

    print("=== Error Search ===")
    errors = searcher.search("FileNotFoundError", use_grep=True, top_k=3)
    for result in errors:
        print(f"Content: {result.text.strip()}")
        print("-" * 40)

if __name__ == "__main__":
    demonstrate_grep_search()
```

300 docs/metadata_filtering.md Normal file
@@ -0,0 +1,300 @@

# LEANN Metadata Filtering Usage Guide

## Overview

LEANN provides metadata filtering capabilities that allow you to filter search results based on arbitrary metadata fields set during chunking. This feature enables use cases like spoiler-free book search, document filtering by date or type, code search by file type, and much more.

## Basic Usage

### Adding Metadata to Your Documents

When building your index, add metadata to each text chunk:

```python
from leann.api import LeannBuilder

builder = LeannBuilder("hnsw")

# Add text with metadata
builder.add_text(
    text="Chapter 1: Alice falls down the rabbit hole",
    metadata={
        "chapter": 1,
        "character": "Alice",
        "themes": ["adventure", "curiosity"],
        "word_count": 150
    }
)

builder.build_index("alice_in_wonderland_index")
```

### Searching with Metadata Filters

Use the `metadata_filters` parameter in search calls:

```python
from leann.api import LeannSearcher

searcher = LeannSearcher("alice_in_wonderland_index")

# Search with filters
results = searcher.search(
    query="What happens to Alice?",
    top_k=10,
    metadata_filters={
        "chapter": {"<=": 5},            # Only chapters 1-5
        "spoiler_level": {"!=": "high"}  # No high spoilers
    }
)
```

## Filter Syntax

### Basic Structure

```python
metadata_filters = {
    "field_name": {"operator": value},
    "another_field": {"operator": value}
}
```

### Supported Operators

#### Comparison Operators

- `"=="`: Equal to
- `"!="`: Not equal to
- `"<"`: Less than
- `"<="`: Less than or equal
- `">"`: Greater than
- `">="`: Greater than or equal

```python
# Examples
{"chapter": {"==": 1}}        # Exactly chapter 1
{"page": {">": 100}}          # Pages after 100
{"rating": {">=": 4.0}}       # Rating 4.0 or higher
{"word_count": {"<": 500}}    # Short passages
```

#### Membership Operators

- `"in"`: Value is in list
- `"not_in"`: Value is not in list

```python
# Examples
{"character": {"in": ["Alice", "Bob"]}}          # Alice OR Bob
{"genre": {"not_in": ["horror", "thriller"]}}    # Exclude genres
{"tags": {"in": ["fiction", "adventure"]}}       # Any of these tags
```

#### String Operators

- `"contains"`: String contains substring
- `"starts_with"`: String starts with prefix
- `"ends_with"`: String ends with suffix

```python
# Examples
{"title": {"contains": "alice"}}      # Title contains "alice"
{"filename": {"ends_with": ".py"}}    # Python files
{"author": {"starts_with": "Dr."}}    # Authors with "Dr." prefix
```

#### Boolean Operators

- `"is_true"`: Field is truthy
- `"is_false"`: Field is falsy

```python
# Examples
{"is_published": {"is_true": True}}   # Published content
{"is_draft": {"is_false": False}}     # Not drafts
```

### Multiple Operators on Same Field

You can apply multiple operators to the same field (AND logic):

```python
metadata_filters = {
    "word_count": {
        ">=": 100,   # At least 100 words
        "<=": 500    # At most 500 words
    }
}
```

### Compound Filters

Multiple fields are combined with AND logic:

```python
metadata_filters = {
    "chapter": {"<=": 10},            # Up to chapter 10
    "character": {"==": "Alice"},     # About Alice
    "spoiler_level": {"!=": "high"}   # No major spoilers
}
```

## Use Case Examples

### 1. Spoiler-Free Book Search

```python
# Reader has only read up to chapter 5
def search_spoiler_free(query, max_chapter):
    return searcher.search(
        query=query,
        metadata_filters={
            "chapter": {"<=": max_chapter},
            "spoiler_level": {"in": ["none", "low"]}
        }
    )

results = search_spoiler_free("What happens to Alice?", max_chapter=5)
```

### 2. Document Management by Date

```python
# Find recent documents
recent_docs = searcher.search(
    query="project updates",
    metadata_filters={
        "date": {">=": "2024-01-01"},
        "document_type": {"==": "report"}
    }
)
```

### 3. Code Search by File Type

```python
# Search only Python files
python_code = searcher.search(
    query="authentication function",
    metadata_filters={
        "file_extension": {"==": ".py"},
        "lines_of_code": {"<": 100}
    }
)
```

### 4. Content Filtering by Audience

```python
# Age-appropriate content
family_content = searcher.search(
    query="adventure stories",
    metadata_filters={
        "age_rating": {"in": ["G", "PG"]},
        "content_warnings": {"not_in": ["violence", "adult_themes"]}
    }
)
```

### 5. Multi-Book Series Management

```python
# Search across first 3 books only
early_series = searcher.search(
    query="character development",
    metadata_filters={
        "series": {"==": "Harry Potter"},
        "book_number": {"<=": 3}
    }
)
```

## Running the Example

You can see metadata filtering in action with our spoiler-free book RAG example:

```bash
# Don't forget to set up the environment
uv venv
source .venv/bin/activate

# Set your OpenAI API key (required for embeddings, but you can update the example locally and use ollama instead)
export OPENAI_API_KEY="your-api-key-here"

# Run the spoiler-free book RAG example
uv run examples/spoiler_free_book_rag.py
```

This example demonstrates:

- Building an index with metadata (chapter numbers, characters, themes, locations)
- Searching with filters to avoid spoilers (e.g., only show results up to chapter 5)
- Different scenarios for readers at various points in the book

The example uses Alice's Adventures in Wonderland as sample data and shows how you can search for information without revealing plot points from later chapters.

## Advanced Patterns

### Custom Chunking with Metadata

```python
def chunk_book_with_metadata(book_text, book_info):
    chunks = []

    for chapter_num, chapter_text in parse_chapters(book_text):
        # Extract entities, themes, etc.
        characters = extract_characters(chapter_text)
        themes = classify_themes(chapter_text)
        spoiler_level = assess_spoiler_level(chapter_text, chapter_num)

        # Create chunks with rich metadata
        for paragraph in split_paragraphs(chapter_text):
            chunks.append({
                "text": paragraph,
                "metadata": {
                    "book_title": book_info["title"],
                    "chapter": chapter_num,
                    "characters": characters,
                    "themes": themes,
                    "spoiler_level": spoiler_level,
                    "word_count": len(paragraph.split()),
                    "reading_level": calculate_reading_level(paragraph)
                }
            })

    return chunks
```

## Performance Considerations

### Efficient Filtering Strategies

1. **Post-search filtering**: Filters are applied after the vector search, which is efficient for typical result sets (10-100 results); a small sketch of this pattern follows this list.

2. **Metadata design**: Keep metadata fields simple and avoid deeply nested structures.
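The post-search pattern can be illustrated in a few lines. This is a sketch of the idea, not LEANN's internal code; it assumes each search result exposes a `metadata` dict and covers only the subset of operators used in the examples above:

```python
# Illustrative post-search filtering: over-fetch from the vector search, then keep
# only results whose metadata satisfies the filters (assumes result.metadata exists).
def passes(metadata: dict, filters: dict) -> bool:
    ops = {
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
        "in": lambda a, b: a in b,
    }
    for field, conditions in filters.items():
        if field not in metadata:
            return False  # in this sketch, a missing field fails the filter
        for op, value in conditions.items():
            if not ops[op](metadata[field], value):
                return False
    return True

candidates = searcher.search("What happens to Alice?", top_k=50)  # over-fetch, then filter
filtered = [r for r in candidates if passes(r.metadata, {"chapter": {"<=": 5}})][:10]
```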

### Best Practices

1. **Consistent metadata schema**: Use consistent field names and value types across your documents.

2. **Reasonable metadata size**: Keep metadata reasonably sized to avoid storage overhead.

3. **Type consistency**: Use consistent data types for the same fields (e.g., always integers for chapter numbers).

4. **Index multiple granularities**: Consider chunking at different levels (paragraph, section, chapter) with appropriate metadata.

### Adding Metadata to Existing Indices

To add metadata filtering to existing indices, you'll need to rebuild them with metadata:

```python
# Read existing passages and add metadata
def add_metadata_to_existing_chunks(chunks):
    for chunk in chunks:
        # Extract or assign metadata based on content
        chunk["metadata"] = extract_metadata(chunk["text"])
    return chunks

# Rebuild index with metadata
enhanced_chunks = add_metadata_to_existing_chunks(existing_chunks)
builder = LeannBuilder("hnsw")
for chunk in enhanced_chunks:
    builder.add_text(chunk["text"], chunk["metadata"])
builder.build_index("enhanced_index")
```

75 docs/normalized_embeddings.md Normal file
@@ -0,0 +1,75 @@

# Normalized Embeddings Support in LEANN

LEANN now automatically detects normalized embedding models and sets the appropriate distance metric for optimal performance.

## What are Normalized Embeddings?

Normalized embeddings are vectors with L2 norm = 1 (unit vectors). These embeddings are optimized for cosine similarity rather than Maximum Inner Product Search (MIPS).
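If you are unsure whether a model emits unit-norm vectors, a quick check is to embed a few strings and look at their L2 norms. This is a minimal sketch assuming the `sentence-transformers` package; the model name is just an example:

```python
# Minimal check: does a model produce (approximately) unit-norm embeddings?
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
vecs = model.encode(["hello world", "normalized embeddings"])
norms = np.linalg.norm(vecs, axis=1)
print(norms)  # values close to 1.0 suggest a normalized model -> prefer cosine distance
```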

## Automatic Detection

When you create a `LeannBuilder` instance with a normalized embedding model, LEANN will:

1. **Automatically set `distance_metric="cosine"`** if not specified
2. **Show a warning** if you manually specify a different distance metric
3. **Provide optimal search performance** with the correct metric

## Supported Normalized Embedding Models

### OpenAI

All OpenAI text embedding models are normalized:

- `text-embedding-ada-002`
- `text-embedding-3-small`
- `text-embedding-3-large`

### Voyage AI

All Voyage AI embedding models are normalized:

- `voyage-2`
- `voyage-3`
- `voyage-large-2`
- `voyage-multilingual-2`
- `voyage-code-2`

### Cohere

All Cohere embedding models are normalized:

- `embed-english-v3.0`
- `embed-multilingual-v3.0`
- `embed-english-light-v3.0`
- `embed-multilingual-light-v3.0`

## Example Usage

```python
from leann.api import LeannBuilder

# Automatic detection - will use cosine distance
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="text-embedding-3-small",
    embedding_mode="openai"
)
# Warning: Detected normalized embeddings model 'text-embedding-3-small'...
# Automatically setting distance_metric='cosine'

# Manual override (not recommended)
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="text-embedding-3-small",
    embedding_mode="openai",
    distance_metric="mips"  # Will show warning
)
# Warning: Using 'mips' distance metric with normalized embeddings...
```

## Non-Normalized Embeddings

Models like `facebook/contriever` and other sentence-transformers models that are not normalized will continue to use MIPS by default, which is optimal for them.

## Why This Matters

Using the wrong distance metric with normalized embeddings can lead to:

- **Poor search quality** due to HNSW's early termination with narrow score ranges
- **Incorrect ranking** of search results
- **Suboptimal performance** compared to using the correct metric

For more details on why this happens, see our analysis in the [embedding detection code](../packages/leann-core/src/leann/api.py) which automatically handles normalized embeddings and MIPS distance metric issues.

21 docs/roadmap.md Normal file
@@ -0,0 +1,21 @@

# 📈 Roadmap

## 🎯 Q2 2025

- [X] HNSW backend integration
- [X] DiskANN backend with MIPS/L2/Cosine support
- [X] Real-time embedding pipeline
- [X] Memory-efficient graph pruning

## 🚀 Q3 2025

- [ ] Advanced caching strategies
- [ ] Add contextual retrieval (https://www.anthropic.com/news/contextual-retrieval)
- [ ] Add sleep-time compute and a summarization agent to summarize files on your computer
- [ ] Add OpenAI recompute API

## 🌟 Q4 2025

- [ ] Integration with LangChain/LlamaIndex
- [ ] Visual similarity search
- [ ] Query rewriting, reranking, and expansion

88 examples/basic_demo.py Normal file
@@ -0,0 +1,88 @@

"""
Simple demo showing basic leann usage
Run: uv run python examples/basic_demo.py
"""

import argparse

from leann import LeannBuilder, LeannChat, LeannSearcher


def main():
    parser = argparse.ArgumentParser(
        description="Simple demo of Leann with selectable embedding models."
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="sentence-transformers/all-mpnet-base-v2",
        help="The embedding model to use, e.g., 'sentence-transformers/all-mpnet-base-v2' or 'text-embedding-ada-002'.",
    )
    args = parser.parse_args()

    print(f"=== Leann Simple Demo with {args.embedding_model} ===")
    print()

    # Sample knowledge base
    chunks = [
        "Machine learning is a subset of artificial intelligence that enables computers to learn without being explicitly programmed.",
        "Deep learning uses neural networks with multiple layers to process data and make decisions.",
        "Natural language processing helps computers understand and generate human language.",
        "Computer vision enables machines to interpret and understand visual information from images and videos.",
        "Reinforcement learning teaches agents to make decisions by receiving rewards or penalties for their actions.",
        "Data science combines statistics, programming, and domain expertise to extract insights from data.",
        "Big data refers to extremely large datasets that require special tools and techniques to process.",
        "Cloud computing provides on-demand access to computing resources over the internet.",
    ]

    print("1. Building index (no embeddings stored)...")
    builder = LeannBuilder(
        embedding_model=args.embedding_model,
        backend_name="hnsw",
    )
    for chunk in chunks:
        builder.add_text(chunk)
    builder.build_index("demo_knowledge.leann")
    print()

    print("2. Searching with real-time embeddings...")
    searcher = LeannSearcher("demo_knowledge.leann")

    queries = [
        "What is machine learning?",
        "How does neural network work?",
        "Tell me about data processing",
    ]

    for query in queries:
        print(f"Query: {query}")
        results = searcher.search(query, top_k=2)

        for i, result in enumerate(results, 1):
            print(f"  {i}. Score: {result.score:.3f}")
            print(f"     Text: {result.text[:100]}...")
        print()

    print("3. Interactive chat demo:")
    print("   (Note: Requires OpenAI API key for real responses)")

    chat = LeannChat("demo_knowledge.leann")

    # Demo questions
    demo_questions: list[str] = [
        "What is the difference between machine learning and deep learning?",
        "How is data science related to big data?",
    ]

    for question in demo_questions:
        print(f"  Q: {question}")
        response = chat.ask(question)
        print(f"  A: {response}")
        print()

    print("Demo completed! Try running:")
    print("  uv run python apps/document_rag.py")


if __name__ == "__main__":
    main()

404 examples/dynamic_update_no_recompute.py Normal file
@@ -0,0 +1,404 @@

"""Dynamic HNSW update demo without compact storage.

This script reproduces the minimal scenario we used while debugging on-the-fly
recompute:

1. Build a non-compact HNSW index from the first few paragraphs of a text file.
2. Print the top results with `recompute_embeddings=True`.
3. Append additional paragraphs with :meth:`LeannBuilder.update_index`.
4. Run the same query again to show the newly inserted passages.

Run it with ``uv`` (optionally pointing LEANN_HNSW_LOG_PATH at a file to inspect
ZMQ activity)::

    LEANN_HNSW_LOG_PATH=embedding_fetch.log \
        uv run -m examples.dynamic_update_no_recompute \
        --index-path .leann/examples/leann-demo.leann

By default the script builds an index from ``data/2501.14312v1 (1).pdf`` and
then updates it with LEANN-related material from ``data/2506.08276v1.pdf``.
It issues the query "What's LEANN?" before and after the update to show how the
new passages become immediately searchable. The script uses the
``sentence-transformers/all-MiniLM-L6-v2`` model with ``is_recompute=True`` so
Faiss pulls existing vectors on demand via the ZMQ embedding server, while
freshly added passages are embedded locally just like the initial build.

To make storage comparisons easy, the script can also build a matching
``is_recompute=False`` baseline (enabled by default) and report the index size
delta after the update. Disable the baseline run with
``--skip-compare-no-recompute`` if you only need the recompute flow.
"""

import argparse
import json
from collections.abc import Iterable
from pathlib import Path
from typing import Any

from leann.api import LeannBuilder, LeannSearcher
from leann.registry import register_project_directory

from apps.chunking import create_text_chunks

REPO_ROOT = Path(__file__).resolve().parents[1]

DEFAULT_QUERY = "What's LEANN?"
DEFAULT_INITIAL_FILES = [REPO_ROOT / "data" / "2501.14312v1 (1).pdf"]
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]


def load_chunks_from_files(paths: list[Path]) -> list[str]:
    from llama_index.core import SimpleDirectoryReader

    documents = []
    for path in paths:
        p = path.expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Input path not found: {p}")
        if p.is_dir():
            reader = SimpleDirectoryReader(str(p), recursive=False)
            documents.extend(reader.load_data(show_progress=True))
        else:
            reader = SimpleDirectoryReader(input_files=[str(p)])
            documents.extend(reader.load_data(show_progress=True))

    if not documents:
        return []

    chunks = create_text_chunks(
        documents,
        chunk_size=512,
        chunk_overlap=128,
        use_ast_chunking=False,
    )
    return [c for c in chunks if isinstance(c, str) and c.strip()]


def run_search(index_path: Path, query: str, top_k: int, *, recompute_embeddings: bool) -> list:
    searcher = LeannSearcher(str(index_path))
    try:
        return searcher.search(
            query=query,
            top_k=top_k,
            recompute_embeddings=recompute_embeddings,
            batch_size=16,
        )
    finally:
        searcher.cleanup()


def print_results(title: str, results: Iterable) -> None:
    print(f"\n=== {title} ===")
    res_list = list(results)
    print(f"results count: {len(res_list)}")
    print("passages:")
    if not res_list:
        print("  (no passages returned)")
    for res in res_list:
        snippet = res.text.replace("\n", " ")[:120]
        print(f"  - {res.id}: {snippet}... (score={res.score:.4f})")


def build_initial_index(
    index_path: Path,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    is_recompute: bool,
) -> None:
    builder = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=is_recompute,
    )
    for idx, passage in enumerate(paragraphs):
        builder.add_text(passage, metadata={"id": str(idx)})
    builder.build_index(str(index_path))


def update_index(
    index_path: Path,
    start_id: int,
    paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    is_recompute: bool,
) -> None:
    updater = LeannBuilder(
        backend_name="hnsw",
        embedding_model=model_name,
        embedding_mode=embedding_mode,
        is_compact=False,
        is_recompute=is_recompute,
    )
    for offset, passage in enumerate(paragraphs, start=start_id):
        updater.add_text(passage, metadata={"id": str(offset)})
    updater.update_index(str(index_path))


def ensure_index_dir(index_path: Path) -> None:
    index_path.parent.mkdir(parents=True, exist_ok=True)


def cleanup_index_files(index_path: Path) -> None:
    """Remove leftover index artifacts for a clean rebuild."""

    parent = index_path.parent
    if not parent.exists():
        return
    stem = index_path.stem
    for file in parent.glob(f"{stem}*"):
        if file.is_file():
            file.unlink()


def index_file_size(index_path: Path) -> int:
    """Return the size of the primary .index file for the given index path."""

    index_file = index_path.parent / f"{index_path.stem}.index"
    return index_file.stat().st_size if index_file.exists() else 0


def load_metadata_snapshot(index_path: Path) -> dict[str, Any] | None:
    meta_path = index_path.parent / f"{index_path.name}.meta.json"
    if not meta_path.exists():
        return None
    try:
        return json.loads(meta_path.read_text())
    except json.JSONDecodeError:
        return None


def run_workflow(
    *,
    label: str,
    index_path: Path,
    initial_paragraphs: list[str],
    update_paragraphs: list[str],
    model_name: str,
    embedding_mode: str,
    is_recompute: bool,
    query: str,
    top_k: int,
) -> dict[str, Any]:
    prefix = f"[{label}] " if label else ""

    ensure_index_dir(index_path)
    cleanup_index_files(index_path)

    print(f"{prefix}Building initial index...")
    build_initial_index(
        index_path,
        initial_paragraphs,
        model_name,
        embedding_mode,
        is_recompute=is_recompute,
    )

    initial_size = index_file_size(index_path)
    before_results = run_search(
        index_path,
        query,
        top_k,
        recompute_embeddings=is_recompute,
    )

    print(f"\n{prefix}Updating index with additional passages...")
    update_index(
        index_path,
        start_id=len(initial_paragraphs),
        paragraphs=update_paragraphs,
        model_name=model_name,
        embedding_mode=embedding_mode,
        is_recompute=is_recompute,
    )

    after_results = run_search(
        index_path,
        query,
        top_k,
        recompute_embeddings=is_recompute,
    )
    updated_size = index_file_size(index_path)

    return {
        "initial_size": initial_size,
        "updated_size": updated_size,
        "delta": updated_size - initial_size,
        "before_results": before_results,
        "after_results": after_results,
        "metadata": load_metadata_snapshot(index_path),
    }


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--initial-files",
        type=Path,
        nargs="+",
        default=DEFAULT_INITIAL_FILES,
        help="Initial document files (PDF/TXT) used to build the base index",
    )
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/examples/leann-demo.leann"),
        help="Destination index path (default: .leann/examples/leann-demo.leann)",
    )
    parser.add_argument(
        "--initial-count",
        type=int,
        default=8,
        help="Number of chunks to use from the initial documents (default: 8)",
    )
    parser.add_argument(
        "--update-files",
        type=Path,
        nargs="*",
        default=DEFAULT_UPDATE_FILES,
        help="Additional documents to add during update (PDF/TXT)",
    )
    parser.add_argument(
        "--update-count",
        type=int,
        default=4,
        help="Number of chunks to append from update documents (default: 4)",
    )
    parser.add_argument(
        "--update-text",
        type=str,
        default=(
            "LEANN (Lightweight Embedding ANN) is an indexing toolkit focused on "
            "recompute-aware HNSW graphs, allowing embeddings to be regenerated "
            "on demand to keep disk usage minimal."
        ),
        help="Fallback text to append if --update-files is omitted",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=4,
        help="Number of results to show for each search (default: 4)",
    )
    parser.add_argument(
        "--query",
        type=str,
        default=DEFAULT_QUERY,
        help="Query to run before/after the update",
    )
    parser.add_argument(
        "--embedding-model",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model name",
    )
    parser.add_argument(
        "--embedding-mode",
        type=str,
        default="sentence-transformers",
        choices=["sentence-transformers", "openai", "mlx", "ollama"],
        help="Embedding backend mode",
    )
    parser.add_argument(
        "--compare-no-recompute",
        dest="compare_no_recompute",
        action="store_true",
        help="Also run a baseline with is_recompute=False and report its index growth.",
    )
    parser.add_argument(
        "--skip-compare-no-recompute",
        dest="compare_no_recompute",
        action="store_false",
        help="Skip building the no-recompute baseline.",
    )
    parser.set_defaults(compare_no_recompute=True)
    args = parser.parse_args()

    ensure_index_dir(args.index_path)
    register_project_directory(REPO_ROOT)

    initial_chunks = load_chunks_from_files(list(args.initial_files))
    if not initial_chunks:
        raise ValueError("No text chunks extracted from the initial files.")

    initial = initial_chunks[: args.initial_count]
    if not initial:
        raise ValueError("Initial chunk set is empty after applying --initial-count.")

    if args.update_files:
        update_chunks = load_chunks_from_files(list(args.update_files))
        if not update_chunks:
            raise ValueError("No text chunks extracted from the update files.")
        to_add = update_chunks[: args.update_count]
    else:
        if not args.update_text:
            raise ValueError("Provide --update-files or --update-text for the update step.")
        to_add = [args.update_text]
    if not to_add:
        raise ValueError("Update chunk set is empty after applying --update-count.")

    recompute_stats = run_workflow(
        label="recompute",
        index_path=args.index_path,
        initial_paragraphs=initial,
        update_paragraphs=to_add,
        model_name=args.embedding_model,
        embedding_mode=args.embedding_mode,
        is_recompute=True,
        query=args.query,
        top_k=args.top_k,
    )

    print_results("initial search", recompute_stats["before_results"])
    print_results("after update", recompute_stats["after_results"])
    print(
        f"\n[recompute] Index file size change: {recompute_stats['initial_size']} -> {recompute_stats['updated_size']} bytes"
        f" (Δ {recompute_stats['delta']})"
    )

    if recompute_stats["metadata"]:
        meta_view = {k: recompute_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
        print("[recompute] metadata snapshot:")
        print(json.dumps(meta_view, indent=2))

    if args.compare_no_recompute:
        baseline_path = (
            args.index_path.parent / f"{args.index_path.stem}-norecompute{args.index_path.suffix}"
        )
        baseline_stats = run_workflow(
            label="no-recompute",
            index_path=baseline_path,
            initial_paragraphs=initial,
            update_paragraphs=to_add,
            model_name=args.embedding_model,
            embedding_mode=args.embedding_mode,
            is_recompute=False,
            query=args.query,
            top_k=args.top_k,
        )

        print(
            f"\n[no-recompute] Index file size change: {baseline_stats['initial_size']} -> {baseline_stats['updated_size']} bytes"
            f" (Δ {baseline_stats['delta']})"
        )

        after_texts = [res.text for res in recompute_stats["after_results"]]
        baseline_after_texts = [res.text for res in baseline_stats["after_results"]]
        if after_texts == baseline_after_texts:
            print(
                "[no-recompute] Search results match recompute baseline; see above for the shared output."
            )
        else:
            print("[no-recompute] WARNING: search results differ from recompute baseline.")

        if baseline_stats["metadata"]:
            meta_view = {k: baseline_stats["metadata"].get(k) for k in ("is_compact", "is_pruned")}
            print("[no-recompute] metadata snapshot:")
            print(json.dumps(meta_view, indent=2))


if __name__ == "__main__":
    main()

@@ -1,122 +0,0 @@

import os
import email
from pathlib import Path
from typing import List, Any
from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


def find_all_messages_directories(root: str = None) -> List[Path]:
    """
    Recursively find all 'Messages' directories under the given root.
    Returns a list of Path objects.
    """
    if root is None:
        # Auto-detect user's mail path
        home_dir = os.path.expanduser("~")
        root = os.path.join(home_dir, "Library", "Mail")

    messages_dirs = []
    for dirpath, dirnames, filenames in os.walk(root):
        if os.path.basename(dirpath) == "Messages":
            messages_dirs.append(Path(dirpath))
    return messages_dirs


class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> List[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: List[Document] = []
        max_count = load_kwargs.get('max_count', 1000)
        count = 0

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                if count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split('\n', 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get('Subject', 'No Subject')
                                from_addr = msg.get('From', 'Unknown')
                                to_addr = msg.get('To', 'Unknown')
                                date = msg.get('Date', 'Unknown')

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if part.get_content_type() == "text/plain" or part.get_content_type() == "text/html":
                                            if part.get_content_type() == "text/html" and not self.include_html:
                                                continue
                                            body += part.get_payload(decode=True).decode('utf-8', errors='ignore')
                                        # break
                                else:
                                    body = msg.get_payload(decode=True).decode('utf-8', errors='ignore')

                                # Create document content with metadata embedded in text
                                doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""

                                # No separate metadata - everything is in the text
                                doc = Document(text=doc_content, metadata={})
                                docs.append(doc)
                                count += 1

                            except Exception as e:
                                print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        print(f"Error reading file {filepath}: {e}")
                        continue

        print(f"Loaded {len(docs)} email documents")
        return docs

@@ -1,192 +0,0 @@

"""
Mbox parser.

Contains simple parser for mbox files.

"""

import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
from fsspec import AbstractFileSystem

from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document

logger = logging.getLogger(__name__)


class MboxReader(BaseReader):
    """
    Mbox parser.

    Extract messages from mailbox files.
    Returns string including date, subject, sender, receiver and
    content for each message.

    """

    DEFAULT_MESSAGE_FORMAT: str = (
        "Date: {_date}\n"
        "From: {_from}\n"
        "To: {_to}\n"
        "Subject: {_subject}\n"
        "Content: {_content}"
    )

    def __init__(
        self,
        *args: Any,
        max_count: int = 0,
        message_format: str = DEFAULT_MESSAGE_FORMAT,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        try:
            from bs4 import BeautifulSoup  # noqa
        except ImportError:
            raise ImportError(
                "`beautifulsoup4` package not found: `pip install beautifulsoup4`"
            )

        super().__init__(*args, **kwargs)
        self.max_count = max_count
        self.message_format = message_format

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse file into string."""
        # Import required libraries
        import mailbox
        from email.parser import BytesParser
        from email.policy import default

        from bs4 import BeautifulSoup

        if fs:
            logger.warning(
                "fs was specified but MboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        i = 0
        results: List[str] = []
        # Load file using mailbox
        bytes_parser = BytesParser(policy=default).parse
        mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore

        # Iterate through all messages
        for _, _msg in enumerate(mbox):
            try:
                msg: mailbox.mboxMessage = _msg
                # Parse multipart messages
                if msg.is_multipart():
                    for part in msg.walk():
                        ctype = part.get_content_type()
                        cdispo = str(part.get("Content-Disposition"))
                        if "attachment" in cdispo:
                            print(f"Attachment found: {part.get_filename()}")
                        if ctype == "text/plain" and "attachment" not in cdispo:
                            content = part.get_payload(decode=True)  # decode
                            break
                # Get plain message payload for non-multipart messages
                else:
                    content = msg.get_payload(decode=True)

                # Parse message HTML content and remove unneeded whitespace
                soup = BeautifulSoup(content)
                stripped_content = " ".join(soup.get_text().split())
                # Format message to include date, sender, receiver and subject
                msg_string = self.message_format.format(
                    _date=msg["date"],
                    _from=msg["from"],
                    _to=msg["to"],
                    _subject=msg["subject"],
                    _content=stripped_content,
                )
                # Add message string to results
                results.append(msg_string)
            except Exception as e:
                logger.warning(f"Failed to parse message:\n{_msg}\n with exception {e}")

            # Increment counter and return if max count is met
            i += 1
            if self.max_count > 0 and i >= self.max_count:
                break

        return [Document(text=result, metadata=extra_info or {}) for result in results]


class EmlxMboxReader(MboxReader):
    """
    EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.

    Extends MboxReader to work with Apple Mail's .emlx format by:
    1. Reading .emlx files from a directory
    2. Converting them to mbox format in memory
    3. Using the parent MboxReader's parsing logic
    """

    def load_data(
        self,
        directory: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse .emlx files from directory into strings using MboxReader logic."""
        import tempfile
        import os

        if fs:
            logger.warning(
                "fs was specified but EmlxMboxReader doesn't support loading "
                "from fsspec filesystems. Will load from local filesystem instead."
            )

        # Find all .emlx files in the directory
        emlx_files = list(directory.glob("*.emlx"))
        logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")

        if not emlx_files:
            logger.warning(f"No .emlx files found in {directory}")
            return []

        # Create a temporary mbox file
        with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
            temp_mbox_path = temp_mbox.name

            # Convert .emlx files to mbox format
            for emlx_file in emlx_files:
                try:
                    # Read the .emlx file
                    with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()

                    # .emlx format: first line is length, rest is email content
                    lines = content.split('\n', 1)
                    if len(lines) >= 2:
                        email_content = lines[1]  # Skip the length line

                        # Write to mbox format (each message starts with "From " and ends with blank line)
                        temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")

                except Exception as e:
                    logger.warning(f"Failed to process {emlx_file}: {e}")
                    continue

        # Close the temporary file so MboxReader can read it
        temp_mbox.close()

        try:
            # Use the parent MboxReader's logic to parse the mbox file
            return super().load_data(Path(temp_mbox_path), extra_info, fs)
        finally:
            # Clean up temporary file
            try:
                os.unlink(temp_mbox_path)
            except:
                pass
@@ -1,285 +0,0 @@
import os
import asyncio
import argparse

try:
    import dotenv
    dotenv.load_dotenv()
except ModuleNotFoundError:
    # python-dotenv is not installed; skip loading environment variables
    dotenv = None

from pathlib import Path
from typing import List, Any

from leann.api import LeannBuilder, LeannSearcher, LeannChat
from llama_index.core.node_parser import SentenceSplitter

# dotenv.load_dotenv() # handled above if python-dotenv is available

# Default Chrome profile path
DEFAULT_CHROME_PROFILE = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")


def create_leann_index_from_multiple_chrome_profiles(profile_dirs: List[Path], index_path: str = "chrome_history_index.leann", max_count: int = -1):
    """
    Create LEANN index from multiple Chrome profile data sources.

    Args:
        profile_dirs: List of Path objects pointing to Chrome profile directories
        index_path: Path to save the LEANN index
        max_count: Maximum number of history entries to process per profile
    """
    print("Creating LEANN index from multiple Chrome profile data sources...")

    # Load documents using ChromeHistoryReader from history_data
    from history_data.history import ChromeHistoryReader
    reader = ChromeHistoryReader()

    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        all_documents = []
        total_processed = 0

        # Process each Chrome profile directory
        for i, profile_dir in enumerate(profile_dirs):
            print(f"\nProcessing Chrome profile {i+1}/{len(profile_dirs)}: {profile_dir}")

            try:
                documents = reader.load_data(
                    chrome_profile_path=str(profile_dir),
                    max_count=max_count
                )
                if documents:
                    print(f"Loaded {len(documents)} history documents from {profile_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)

                    # Check if we've reached the max count
                    if max_count > 0 and total_processed >= max_count:
                        print(f"Reached max count of {max_count} documents")
                        break
                else:
                    print(f"No documents loaded from {profile_dir}")
            except Exception as e:
                print(f"Error processing {profile_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            # Highlight that every Chrome window must be closed before running this script
            print("\033[91mYou need to close or quit all Chrome browsers before running this script\033[0m")
            return None

        print(f"\nTotal loaded {len(all_documents)} history documents from {len(profile_dirs)} profiles")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=128)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in all_documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                text = node.get_content()
                # text = '[Title] ' + doc.metadata["title"] + '\n' + text
                all_texts.append(text)

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")

        # Create the LEANN index directory
        INDEX_DIR.mkdir(exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


def create_leann_index(profile_path: str = None, index_path: str = "chrome_history_index.leann", max_count: int = 1000):
    """
    Create LEANN index from Chrome history data.

    Args:
        profile_path: Path to the Chrome profile directory (optional, uses default if None)
        index_path: Path to save the LEANN index
        max_count: Maximum number of history entries to process
    """
    print("Creating LEANN index from Chrome history data...")
    INDEX_DIR = Path(index_path).parent

    if not INDEX_DIR.exists():
        print(f"--- Index directory not found, building new index ---")
        INDEX_DIR.mkdir(exist_ok=True)

        print(f"--- Building new LEANN index ---")

        print(f"\n[PHASE 1] Building Leann index...")

        # Load documents using ChromeHistoryReader from history_data
        from history_data.history import ChromeHistoryReader
        reader = ChromeHistoryReader()

        documents = reader.load_data(
            chrome_profile_path=profile_path,
            max_count=max_count
        )

        if not documents:
            print("No documents loaded. Exiting.")
            return None

        print(f"Loaded {len(documents)} history documents")

        # Create text splitter with 256 chunk size
        text_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=25)

        # Convert Documents to text strings and chunk them
        all_texts = []
        for doc in documents:
            # Split the document into chunks
            nodes = text_splitter.get_nodes_from_documents([doc])
            for node in nodes:
                all_texts.append(node.get_content())

        print(f"Created {len(all_texts)} text chunks from {len(documents)} documents")

        # Use HNSW backend for better macOS compatibility
        builder = LeannBuilder(
            backend_name="hnsw",
            embedding_model="facebook/contriever",
            graph_degree=32,
            complexity=64,
            is_compact=True,
            is_recompute=True,
            num_threads=1  # Force single-threaded mode
        )

        print(f"Adding {len(all_texts)} history chunks to index...")
        for chunk_text in all_texts:
            builder.add_text(chunk_text)

        builder.build_index(index_path)
        print(f"\nLEANN index built at {index_path}!")
    else:
        print(f"--- Using existing index at {INDEX_DIR} ---")

    return index_path


async def query_leann_index(index_path: str, query: str):
    """
    Query the LEANN index.

    Args:
        index_path: Path to the LEANN index
        query: The query string
    """
    print(f"\n[PHASE 2] Starting Leann chat session...")
    chat = LeannChat(index_path=index_path)

    print(f"You: {query}")
    chat_response = chat.ask(
        query,
        top_k=10,
        recompute_beighbor_embeddings=True,
        complexity=32,
        beam_width=1,
        llm_config={
            "type": "openai",
            "model": "gpt-4o",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
        llm_kwargs={
            "temperature": 0.0,
            "max_tokens": 1000
        }
    )
    print(f"Leann: {chat_response}")


async def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='LEANN Chrome History Reader - Create and query browser history index')
    parser.add_argument('--chrome-profile', type=str, default=DEFAULT_CHROME_PROFILE,
                        help=f'Path to Chrome profile directory (default: {DEFAULT_CHROME_PROFILE}); usually you do not need to change this')
    parser.add_argument('--index-dir', type=str, default="./all_google_new",
                        help='Directory to store the LEANN index (default: ./all_google_new)')
    parser.add_argument('--max-entries', type=int, default=1000,
                        help='Maximum number of history entries to process (default: 1000)')
    parser.add_argument('--query', type=str, default=None,
                        help='Single query to run (default: runs example queries)')
    parser.add_argument('--auto-find-profiles', action='store_true', default=True,
                        help='Automatically find all Chrome profiles (default: True)')

    args = parser.parse_args()

    INDEX_DIR = Path(args.index_dir)
    INDEX_PATH = str(INDEX_DIR / "chrome_history.leann")

    print(f"Using Chrome profile: {args.chrome_profile}")
    print(f"Index directory: {INDEX_DIR}")
    print(f"Max entries: {args.max_entries}")

    # Find Chrome profile directories
    from history_data.history import ChromeHistoryReader

    if args.auto_find_profiles:
        profile_dirs = ChromeHistoryReader.find_chrome_profiles()
        if not profile_dirs:
            print("No Chrome profiles found automatically. Exiting.")
            return
    else:
        # Use single specified profile
        profile_path = Path(args.chrome_profile)
        if not profile_path.exists():
            print(f"Chrome profile not found: {profile_path}")
            return
        profile_dirs = [profile_path]

    # Create or load the LEANN index from all sources
    index_path = create_leann_index_from_multiple_chrome_profiles(profile_dirs, INDEX_PATH, args.max_entries)

    if index_path:
        if args.query:
            # Run single query
            await query_leann_index(index_path, args.query)
        else:
            # Example queries
            queries = [
                "What websites did I visit about machine learning?",
                "Find my search history about programming"
            ]

            for query in queries:
                print("\n" + "="*60)
                await query_leann_index(index_path, query)


if __name__ == "__main__":
    asyncio.run(main())
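Besides the CLI entry point above, the same index-then-query flow can be driven programmatically. A minimal sketch, assuming the functions defined in the script above are in scope (same module), that history_data.history still provides ChromeHistoryReader.find_chrome_profiles(), and that OPENAI_API_KEY is set; the "./my_history_index" output directory is only illustrative:

import asyncio
from pathlib import Path

from history_data.history import ChromeHistoryReader

# Reuse create_leann_index_from_multiple_chrome_profiles and query_leann_index
# from the script above; "./my_history_index" is a hypothetical output location.
profiles = ChromeHistoryReader.find_chrome_profiles()
index_path = create_leann_index_from_multiple_chrome_profiles(
    profiles,
    index_path=str(Path("./my_history_index") / "chrome_history.leann"),
    max_count=500,
)

if index_path:
    asyncio.run(query_leann_index(index_path, "What did I read about vector databases?"))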
examples/grep_search_example.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""
Grep Search Example

Shows how to use grep-based text search instead of semantic search.
Useful when you need exact text matches rather than meaning-based results.
"""

from leann import LeannSearcher

# Load your index
searcher = LeannSearcher("my-documents.leann")

# Regular semantic search
print("=== Semantic Search ===")
results = searcher.search("machine learning algorithms", top_k=3)
for result in results:
    print(f"Score: {result.score:.3f}")
    print(f"Text: {result.text[:80]}...")
    print()

# Grep-based search for exact text matches
print("=== Grep Search ===")
results = searcher.search("def train_model", top_k=3, use_grep=True)
for result in results:
    print(f"Score: {result.score}")
    print(f"Text: {result.text[:80]}...")
    print()

# Find specific error messages
error_results = searcher.search("FileNotFoundError", use_grep=True)
print(f"Found {len(error_results)} results mentioning FileNotFoundError")

# Search for class definitions
func_results = searcher.search("class SearchResult", use_grep=True, top_k=5)
print(f"Found {len(func_results)} class definitions")
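One pattern the example does not show is using grep mode as an exact-match first pass and falling back to semantic search when nothing matches literally. A minimal sketch against the same index name as above, using only the search() calls already shown; the query string is arbitrary:

from leann import LeannSearcher

# Exact-match first, semantic fallback; "my-documents.leann" is the index used above.
searcher = LeannSearcher("my-documents.leann")
query = "LeannBuilder(backend_name="  # arbitrary code-flavored query
hits = searcher.search(query, top_k=5, use_grep=True)
if not hits:
    # Nothing matched literally; fall back to meaning-based retrieval.
    hits = searcher.search(query, top_k=5)

for hit in hits:
    print(f"{hit.score}: {hit.text[:80]}...")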
Some files were not shown because too many files have changed in this diff.