Compare commits: v0.1.1...feature/cu
242 Commits
Commits (SHA1):
9996c29618, 12951ad4d5, a878d2459b, 6c39a3427f, 17cbd07b25, 3629ccf8f7, a0bbf831db, 76cc798e3e,
d599566fd7, 00770aebbb, e268392d5b, eb909ccec5, 969f514564, 1ef9cba7de, a63550944b, 97493a2896,
f7d2dc6e7c, ea86b283cb, e7519bceaa, abf0b2c676, 3c4785bb63, 930b79cc98, 3766ad1fd2, c3aceed1e0,
dc6c9f696e, 2406c41eef, d4f5f2896f, 366984e92e, 64b92a04a7, a85d0ad4a7, dbb5f4d352, f180b83589,
abf312d998, ab251ab751, 28085f6f04, 6495833887, 5543b3c5f7, a99983b3d9, 36482e016c, 32967daf81,
b4bb8dec75, 5ba9cf6442, 1484406a8d, 761ec1f0ac, 4808afc686, 0bba4b2157, e67b5f44fa, 658bce47ef,
6b399ad8d2, 16f35aa067, ab9c6bd69e, e2b37914ce, e588100674, fecee94af1, 01475c10a0, c8aa063f48,
576beb13db, 63c7b0c8a3, ec889f7ef4, 322e5c162d, edde0cdeb2, db7ba27ff6, 5f7806e16f, d034e2195b,
43894ff605, 10311cc611, ad0d2faabc, e93c0dec6f, c5a29f849a, 3b8dc6368e, e309f292de, 0d9f92ea0f,
b0b353d279, 4dffdfedbe, d41e467df9, 4ca0489cb1, e83a671918, 4e5b73ce7b, 31b4973141, dde2221513,
6d11e86e71, 13bb561aad, 0174ba5571, 03af82d695, 738f1dbab8, 37d990d51c, a6f07a54f1, 46905e0687,
838ade231e, da6540decd, 39e18a7c11, 6bde28584b, f62632c41f, 27708243ca, 9a1e4652ca, 14e84d9e2d,
2dcfca19ff, bee2167ee3, ef980d70b3, db3c63c441, 00eeadb9dd, 42c8370709, fafdf8fcbe, 21f7d8e031,
46565b9249, 3dad76126a, 18e28bda32, 609fa62fd5, eab13434ef, b2390ccc14, e8fca2c84a, 790ae14f69,
ac363072e6, 93465af46c, 792ece67dc, 239e35e2e6, 2fac0c6fbf, 9801aa581b, 5e97916608, 8b9c2be8c9,
3ff5aac8e0, 67fef60466, b6ab6f1993, 9f2e82a838, 0b2b799d5a, 0f790fbbd9, 387ae21eba, 3cc329c3e7,
5567302316, 075d4bd167, e4bcc76f88, 710e83b1fd, c96d653072, 8b22d2b5d3, 4cb544ee38, f94ce63d51,
4271ff9d84, 0d448c4a41, af5599e33c, efdf6d917a, dd71ac8d71, 8bee1d4100, 33521d6d00, 8899734952,
54df6310c5, 19bcc07814, 8356e3c668, 08eac5c821, 4671ed9b36, 055c086398, d505dcc5e3, 261006c36a,
b2eba23e21, e9ee687472, 6f5d5e4a77, 5c8921673a, e9d2d420bd, ebabfad066, e6f612b5e8, 51c41acd82,
455f93fb7c, 48207c3b69, 4de1caa40f, 60eaa8165c, c1a5d0c624, af1790395a, 383c6d8d7e, bc0d839693,
8596562de5, 5d09586853, a7cba078dd, b3e9ee96fa, 8537a6b17e, 7c8d7dc5c2, 8e23d663e6, 8a3994bf80,
8375f601ba, c87c0fe662, 73927b68ef, cc1a62e5aa, 802020cb41, cdb92f7cf4, dc69bdec00, 98073e9868,
cf2ef48967, 0692bbf7a2, 52584a171f, efd6b5324b, 2baaa4549b, 35310ddd52, fc9c5cb39d, 8f2a1e87ea,
50caf65f28, 1b48794ca8, 4aef1d814e, 75ddcd6158, 2a4df11f5c, 5eb893c62b, d91ce2e94d, 5c2ff8a641,
d4f474c9b7, 170f7644e9, cd8b970eff, 52153bbb69, e1ae087207, 48c5e12ac1, f8b5c97190, d038c81b8b,
29cbbbd0d6, 179f30bc36, c4a0a68581, 5c836ad08e, 673fd9b7cd, 84b24b233d, 499cdd7822, 800d4cf111,
b6d43f5fd9, 3603cd5034, 6df7893173, e64b599276, 2dd59c4ba1, 166986d5e6, a6aec68f32, ed27a127d5,
d8b4ea7564, f0a2ef96b4, 7d73c2c803, e8d2ecab03, 32a374d094, d45c013806, 9000a7083d, 8307555d54,
20f2aece08, 43eb4f9a1d, 5461b71d8c, 374db0ebb8, cea1f6f87c, 6c0e39372b, 2bec67d2b6, 133e715832,
95cf2f16e2, 47a4c153eb
.gitattributes (vendored, 1 change)
@@ -1 +0,0 @@
-paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.yml (vendored, new file, 50 lines)
@@ -0,0 +1,50 @@
name: Bug Report
description: Report a bug in LEANN
labels: ["bug"]

body:
  - type: textarea
    id: description
    attributes:
      label: What happened?
      description: A clear description of the bug
    validations:
      required: true

  - type: textarea
    id: reproduce
    attributes:
      label: How to reproduce
      placeholder: |
        1. Install with...
        2. Run command...
        3. See error
    validations:
      required: true

  - type: textarea
    id: error
    attributes:
      label: Error message
      description: Paste any error messages
      render: shell

  - type: input
    id: version
    attributes:
      label: LEANN Version
      placeholder: "0.1.0"
    validations:
      required: true

  - type: dropdown
    id: os
    attributes:
      label: Operating System
      options:
        - macOS
        - Linux
        - Windows
        - Docker
    validations:
      required: true
.github/ISSUE_TEMPLATE/config.yml (vendored, new file, 8 lines)
@@ -0,0 +1,8 @@
blank_issues_enabled: true
contact_links:
  - name: Documentation
    url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
    about: Read the docs first
  - name: Discussions
    url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
    about: Ask questions and share ideas
.github/ISSUE_TEMPLATE/feature_request.yml (vendored, new file, 27 lines)
@@ -0,0 +1,27 @@
name: Feature Request
description: Suggest a new feature for LEANN
labels: ["enhancement"]

body:
  - type: textarea
    id: problem
    attributes:
      label: What problem does this solve?
      description: Describe the problem or need
    validations:
      required: true

  - type: textarea
    id: solution
    attributes:
      label: Proposed solution
      description: How would you like this to work?
    validations:
      required: true

  - type: textarea
    id: example
    attributes:
      label: Example usage
      description: Show how the API might look
      render: python
.github/pull_request_template.md (vendored, new file, 13 lines)
@@ -0,0 +1,13 @@
## What does this PR do?

<!-- Brief description of your changes -->

## Related Issues

Fixes #

## Checklist

- [ ] Tests pass (`uv run pytest`)
- [ ] Code formatted (`ruff format` and `ruff check`)
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
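The checklist items above correspond to three local commands. A minimal sketch of running them from the repository root before opening a PR (assuming uv, ruff, and pre-commit are already installed):

```bash
# Run the test suite
uv run pytest

# Format and lint the codebase with ruff
ruff format
ruff check

# Run every configured pre-commit hook against the whole tree
pre-commit run --all-files
```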
.github/workflows/build-and-publish.yml (vendored, 256 changes)
@@ -1,256 +1,12 @@
-name: Build and Publish to PyPI
-
-on:
-  release:
-    types: [published]
-  push:
-    tags:
-      - 'v*'
-  workflow_dispatch:
-    inputs:
-      publish:
-        description: 'Publish to PyPI'
-        required: true
-        default: 'false'
-        type: choice
-        options:
-          - 'false'
-          - 'test'
-          - 'prod'
-
-jobs:
-  # Build pure Python package: leann-core
-  build-core:
-    name: Build leann-core
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-
-      - name: Install build dependencies
-        run: |
-          uv pip install --system build twine
-
-      - name: Build package
-        run: |
-          cd packages/leann-core
-          uv build
-
-      - name: Upload artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: leann-core-dist
-          path: packages/leann-core/dist/
-
-  # Build binary package: leann-backend-hnsw (default backend)
-  build-hnsw:
-    name: Build leann-backend-hnsw
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest]
-        python-version: ['3.9', '3.10', '3.11', '3.12']
-    runs-on: ${{ matrix.os }}
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-
-      - name: Install system dependencies (Ubuntu)
-        if: runner.os == 'Linux'
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
-            pkg-config libopenblas-dev patchelf
-
-      - name: Install system dependencies (macOS)
-        if: runner.os == 'macOS'
-        run: |
-          brew install libomp boost zeromq
-
-      - name: Install build dependencies
-        run: |
-          uv pip install --system scikit-build-core numpy swig
-          uv pip install --system auditwheel delocate
-
-      - name: Build wheel
-        run: |
-          cd packages/leann-backend-hnsw
-          uv build --wheel --python python
-
-      - name: Repair wheel (Linux)
-        if: runner.os == 'Linux'
-        run: |
-          cd packages/leann-backend-hnsw
-          auditwheel repair dist/*.whl -w dist_repaired
-          rm -rf dist
-          mv dist_repaired dist
-
-      - name: Repair wheel (macOS)
-        if: runner.os == 'macOS'
-        run: |
-          cd packages/leann-backend-hnsw
-          delocate-wheel -w dist_repaired -v dist/*.whl
-          rm -rf dist
-          mv dist_repaired dist
-
-      - name: Upload artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: hnsw-${{ matrix.os }}-py${{ matrix.python-version }}
-          path: packages/leann-backend-hnsw/dist/
-
-  # Build binary package: leann-backend-diskann (multi-platform)
-  build-diskann:
-    name: Build leann-backend-diskann
-    strategy:
-      matrix:
-        os: [ubuntu-latest, macos-latest]
-        python-version: ['3.9', '3.10', '3.11', '3.12']
-    runs-on: ${{ matrix.os }}
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: recursive
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-
-      - name: Install system dependencies (Ubuntu)
-        if: runner.os == 'Linux'
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y libomp-dev libboost-all-dev libaio-dev libzmq3-dev \
-            protobuf-compiler libprotobuf-dev libabsl-dev patchelf
-
-          # Install Intel MKL using Intel's installer
-          wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
-          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
-          source /opt/intel/oneapi/setvars.sh
-          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
-          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
-
-      - name: Install system dependencies (macOS)
-        if: runner.os == 'macOS'
-        run: |
-          brew install libomp boost zeromq protobuf
-          # MKL is not available on Homebrew, but DiskANN can work without it
-
-      - name: Install build dependencies
-        run: |
-          uv pip install --system scikit-build-core numpy Cython pybind11
-          if [[ "$RUNNER_OS" == "Linux" ]]; then
-            uv pip install --system auditwheel
-          else
-            uv pip install --system delocate
-          fi
-
-      - name: Build wheel
-        run: |
-          cd packages/leann-backend-diskann
-          uv build --wheel --python python
-
-      - name: Repair wheel (Linux)
-        if: runner.os == 'Linux'
-        run: |
-          cd packages/leann-backend-diskann
-          auditwheel repair dist/*.whl -w dist_repaired
-          rm -rf dist
-          mv dist_repaired dist
-
-      - name: Repair wheel (macOS)
-        if: runner.os == 'macOS'
-        run: |
-          cd packages/leann-backend-diskann
-          delocate-wheel -w dist_repaired -v dist/*.whl
-          rm -rf dist
-          mv dist_repaired dist
-
-      - name: Upload artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: diskann-${{ matrix.os }}-py${{ matrix.python-version }}
-          path: packages/leann-backend-diskann/dist/
-
-  # Build meta-package: leann (build last)
-  build-meta:
-    name: Build leann meta-package
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-
-      - name: Install uv
-        uses: astral-sh/setup-uv@v4
-
-      - name: Install build dependencies
-        run: |
-          uv pip install --system build
-
-      - name: Build package
-        run: |
-          cd packages/leann
-          uv build
-
-      - name: Upload artifacts
-        uses: actions/upload-artifact@v4
-        with:
-          name: leann-meta-dist
-          path: packages/leann/dist/
-
-  # Publish to PyPI
-  publish:
-    name: Publish to PyPI
-    needs: [build-core, build-hnsw, build-diskann, build-meta]
-    runs-on: ubuntu-latest
-    if: github.event_name == 'release' || github.event.inputs.publish != 'false'
-
-    steps:
-      - name: Download all artifacts
-        uses: actions/download-artifact@v4
-        with:
-          path: dist
-
-      - name: Flatten directory structure
-        run: |
-          mkdir -p all_wheels
-          find dist -name "*.whl" -exec cp {} all_wheels/ \;
-          find dist -name "*.tar.gz" -exec cp {} all_wheels/ \;
-
-      - name: Publish to Test PyPI
-        if: github.event.inputs.publish == 'test' || github.event_name == 'workflow_dispatch'
-        uses: pypa/gh-action-pypi-publish@release/v1
-        with:
-          password: ${{ secrets.TEST_PYPI_API_TOKEN }}
-          repository-url: https://test.pypi.org/legacy/
-          packages-dir: all_wheels/
-
-      - name: Publish to PyPI
-        if: github.event_name == 'release' || github.event.inputs.publish == 'prod'
-        uses: pypa/gh-action-pypi-publish@release/v1
-        with:
-          password: ${{ secrets.PYPI_API_TOKEN }}
-          packages-dir: all_wheels/
+name: CI
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+  workflow_dispatch:
+
+jobs:
+  build:
+    uses: ./.github/workflows/build-reusable.yml
.github/workflows/build-reusable.yml (vendored, new file, 451 lines)
@@ -0,0 +1,451 @@
name: Reusable Build

on:
  workflow_call:
    inputs:
      ref:
        description: 'Git ref to build'
        required: false
        type: string
        default: ''

jobs:
  lint:
    name: Lint and Format Check
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive

      - name: Install uv and Python
        uses: astral-sh/setup-uv@v6
        with:
          python-version: '3.11'

      - name: Run pre-commit with only lint group (no project deps)
        run: |
          uv run --only-group lint pre-commit run --all-files --show-diff-on-failure

  build:
    needs: lint
    name: Build ${{ matrix.os }} Python ${{ matrix.python }}
    strategy:
      matrix:
        include:
          - os: ubuntu-22.04
            python: '3.9'
          - os: ubuntu-22.04
            python: '3.10'
          - os: ubuntu-22.04
            python: '3.11'
          - os: ubuntu-22.04
            python: '3.12'
          - os: ubuntu-22.04
            python: '3.13'
          # ARM64 Linux builds
          - os: ubuntu-24.04-arm
            python: '3.9'
          - os: ubuntu-24.04-arm
            python: '3.10'
          - os: ubuntu-24.04-arm
            python: '3.11'
          - os: ubuntu-24.04-arm
            python: '3.12'
          - os: ubuntu-24.04-arm
            python: '3.13'
          - os: macos-14
            python: '3.9'
          - os: macos-14
            python: '3.10'
          - os: macos-14
            python: '3.11'
          - os: macos-14
            python: '3.12'
          - os: macos-14
            python: '3.13'
          - os: macos-15
            python: '3.9'
          - os: macos-15
            python: '3.10'
          - os: macos-15
            python: '3.11'
          - os: macos-15
            python: '3.12'
          - os: macos-15
            python: '3.13'
          - os: macos-13
            python: '3.9'
          - os: macos-13
            python: '3.10'
          - os: macos-13
            python: '3.11'
          - os: macos-13
            python: '3.12'
          # Note: macos-13 + Python 3.13 excluded due to PyTorch compatibility
          # (PyTorch 2.5+ supports Python 3.13 but not Intel Mac x86_64)
    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v5
        with:
          ref: ${{ inputs.ref }}
          submodules: recursive

      - name: Install uv and Python
        uses: astral-sh/setup-uv@v6
        with:
          python-version: ${{ matrix.python }}

      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
            pkg-config libabsl-dev libaio-dev libprotobuf-dev \
            patchelf

          # Debug: Show system information
          echo "🔍 System Information:"
          echo "Architecture: $(uname -m)"
          echo "OS: $(uname -a)"
          echo "CPU info: $(lscpu | head -5)"

          # Install math library based on architecture
          ARCH=$(uname -m)
          echo "🔍 Setting up math library for architecture: $ARCH"

          if [[ "$ARCH" == "x86_64" ]]; then
            # Install Intel MKL for DiskANN on x86_64
            echo "📦 Installing Intel MKL for x86_64..."
            wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
            sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
            source /opt/intel/oneapi/setvars.sh
            echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
            echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
            echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
            echo "✅ Intel MKL installed for x86_64"

            # Debug: Check MKL installation
            echo "🔍 MKL Installation Check:"
            ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
            ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"

          elif [[ "$ARCH" == "aarch64" ]]; then
            # Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
            echo "📦 Installing OpenBLAS for ARM64..."
            sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
            echo "✅ OpenBLAS installed for ARM64"

            # Debug: Check OpenBLAS installation
            echo "🔍 OpenBLAS Installation Check:"
            dpkg -l | grep openblas || echo "OpenBLAS package not found"
            ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
          fi

          # Debug: Show final library paths
          echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          # Don't install LLVM, use system clang for better compatibility
          brew install libomp boost protobuf zeromq

      - name: Install build dependencies
        run: |
          uv python install ${{ matrix.python }}
          uv venv --python ${{ matrix.python }} .uv-build
          if [[ "$RUNNER_OS" == "Windows" ]]; then
            BUILD_PY=".uv-build\\Scripts\\python.exe"
          else
            BUILD_PY=".uv-build/bin/python"
          fi
          uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
          if [[ "$RUNNER_OS" == "Linux" ]]; then
            uv pip install --python "$BUILD_PY" auditwheel
          else
            uv pip install --python "$BUILD_PY" delocate
          fi

          if [[ "$RUNNER_OS" == "Windows" ]]; then
            echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
          else
            echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
          fi

      - name: Set macOS environment variables
        if: runner.os == 'macOS'
        run: |
          # Use brew --prefix to automatically detect Homebrew installation path
          HOMEBREW_PREFIX=$(brew --prefix)
          echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
          echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV

          # Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
          echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV

          # Set compiler flags for OpenMP (required for both backends)
          echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
          echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV

      - name: Build packages
        run: |
          # Build core (platform independent)
          cd packages/leann-core
          uv build
          cd ../..

          # Build HNSW backend
          cd packages/leann-backend-hnsw
          if [[ "${{ matrix.os }}" == macos-* ]]; then
            # Use system clang for better compatibility
            export CC=clang
            export CXX=clang++
            # Homebrew libraries on each macOS version require matching minimum version
            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=13.0
            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=14.0
            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=15.0
            fi
            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
          else
            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
          fi
          cd ../..

          # Build DiskANN backend
          cd packages/leann-backend-diskann
          if [[ "${{ matrix.os }}" == macos-* ]]; then
            # Use system clang for better compatibility
            export CC=clang
            export CXX=clang++
            # DiskANN requires macOS 13.3+ for sgesdd_ LAPACK function
            # But Homebrew libraries on each macOS version require matching minimum version
            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=13.3
            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=14.0
            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=15.0
            fi
            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
          else
            uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
          fi
          cd ../..

          # Build meta package (platform independent)
          cd packages/leann
          uv build
          cd ../..

      - name: Repair wheels (Linux)
        if: runner.os == 'Linux'
        run: |
          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            auditwheel repair dist/*.whl -w dist_repaired
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: Repair wheels (macOS)
        if: runner.os == 'macOS'
        run: |
          # Determine deployment target based on runner OS
          # Must match the Homebrew libraries for each macOS version
          if [[ "${{ matrix.os }}" == "macos-13" ]]; then
            HNSW_TARGET="13.0"
            DISKANN_TARGET="13.3"
          elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
            HNSW_TARGET="14.0"
            DISKANN_TARGET="14.0"
          elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
            HNSW_TARGET="15.0"
            DISKANN_TARGET="15.0"
          fi

          # Repair HNSW wheel
          cd packages/leann-backend-hnsw
          if [ -d dist ]; then
            export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
            delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

          # Repair DiskANN wheel
          cd packages/leann-backend-diskann
          if [ -d dist ]; then
            export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
            delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
            rm -rf dist
            mv dist_repaired dist
          fi
          cd ../..

      - name: List built packages
        run: |
          echo "📦 Built packages:"
          find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort

      - name: Install built packages for testing
        run: |
          # Create uv-managed virtual environment with the requested interpreter
          uv python install ${{ matrix.python }}
          uv venv --python ${{ matrix.python }}
          source .venv/bin/activate || source .venv/Scripts/activate

          if [[ "$RUNNER_OS" == "Windows" ]]; then
            UV_PY=".venv\\Scripts\\python.exe"
          else
            UV_PY=".venv/bin/python"
          fi

          # Install test dependency group only (avoids reinstalling project package)
          uv pip install --python "$UV_PY" --group test

          # Install core wheel built in this job
          CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
          if [[ -n "$CORE_WHL" ]]; then
            uv pip install --python "$UV_PY" "$CORE_WHL"
          else
            uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
          fi

          PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")

          if [[ "$RUNNER_OS" == "macOS" ]]; then
            if [[ "${{ matrix.os }}" == "macos-13" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=13.3
            elif [[ "${{ matrix.os }}" == "macos-14" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=14.0
            elif [[ "${{ matrix.os }}" == "macos-15" ]]; then
              export MACOSX_DEPLOYMENT_TARGET=15.0
            fi
          fi

          HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
          if [[ -z "$HNSW_WHL" ]]; then
            HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
          fi
          if [[ -n "$HNSW_WHL" ]]; then
            uv pip install --python "$UV_PY" "$HNSW_WHL"
          else
            uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
          fi

          DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
          if [[ -z "$DISKANN_WHL" ]]; then
            DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
          fi
          if [[ -n "$DISKANN_WHL" ]]; then
            uv pip install --python "$UV_PY" "$DISKANN_WHL"
          else
            uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
          fi

          LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
          if [[ -n "$LEANN_WHL" ]]; then
            uv pip install --python "$UV_PY" "$LEANN_WHL"
          else
            uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
          fi

      - name: Run tests with pytest
        env:
          CI: true
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
          HF_HUB_DISABLE_SYMLINKS: 1
          TOKENIZERS_PARALLELISM: false
          PYTORCH_ENABLE_MPS_FALLBACK: 0
          OMP_NUM_THREADS: 1
          MKL_NUM_THREADS: 1
        run: |
          source .venv/bin/activate || source .venv/Scripts/activate
          pytest tests/ -v --tb=short

      - name: Run sanity checks (optional)
        run: |
          # Activate virtual environment
          source .venv/bin/activate || source .venv/Scripts/activate

          # Run distance function tests if available
          if [ -f test/sanity_checks/test_distance_functions.py ]; then
            echo "Running distance function sanity checks..."
            python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
          fi

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: packages-${{ matrix.os }}-py${{ matrix.python }}
          path: packages/*/dist/

  arch-smoke:
    name: Arch Linux smoke test (install & import)
    needs: build
    runs-on: ubuntu-latest
    container:
      image: archlinux:latest

    steps:
      - name: Prepare system
        run: |
          pacman -Syu --noconfirm
          pacman -S --noconfirm python python-pip gcc git zlib openssl

      - name: Download ALL wheel artifacts from this run
        uses: actions/download-artifact@v5
        with:
          # Don't specify name, download all artifacts
          path: ./wheels

      - name: Install uv
        uses: astral-sh/setup-uv@v6

      - name: Create virtual environment and install wheels
        run: |
          uv venv
          source .venv/bin/activate || source .venv/Scripts/activate
          uv pip install --find-links wheels leann-core
          uv pip install --find-links wheels leann-backend-hnsw
          uv pip install --find-links wheels leann-backend-diskann
          uv pip install --find-links wheels leann

      - name: Import & tiny runtime check
        env:
          OMP_NUM_THREADS: 1
          MKL_NUM_THREADS: 1
        run: |
          source .venv/bin/activate || source .venv/Scripts/activate
          python - <<'PY'
          import leann
          import leann_backend_hnsw as h
          import leann_backend_diskann as d
          from leann import LeannBuilder, LeannSearcher
          b = LeannBuilder(backend_name="hnsw")
          b.add_text("hello arch")
          b.build_index("arch_demo.leann")
          s = LeannSearcher("arch_demo.leann")
          print("search:", s.search("hello", top_k=1))
          PY
.github/workflows/ci.yml (vendored, removed, 110 lines)
@@ -1,110 +0,0 @@
name: CI - Build and Test

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build-test:
    strategy:
      matrix:
        os: [ubuntu-latest, macos-latest]
        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']

    runs-on: ${{ matrix.os }}

    steps:
      - uses: actions/checkout@v4
        with:
          submodules: recursive

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          echo "$HOME/.cargo/bin" >> $GITHUB_PATH

      - name: Install system dependencies (Ubuntu)
        if: runner.os == 'Linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y libomp-dev libboost-all-dev libzmq3-dev \
            pkg-config libopenblas-dev patchelf \
            libaio-dev protobuf-compiler libprotobuf-dev libabsl-dev

          # Install Intel MKL for DiskANN
          wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
          sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
          source /opt/intel/oneapi/setvars.sh
          echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
          echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV

      - name: Install system dependencies (macOS)
        if: runner.os == 'macOS'
        run: |
          brew install libomp boost zeromq protobuf

      - name: Build all packages
        run: |
          echo "🔨 Building on ${{ matrix.os }} with Python ${{ matrix.python-version }}..."
          export UV_SYSTEM_PYTHON=1

          # Verify Python version
          python --version
          which python

          # Build each package
          for pkg in leann-core leann-backend-hnsw leann-backend-diskann leann; do
            echo "Building $pkg..."
            cd packages/$pkg
            rm -rf dist/ build/ _skbuild/
            # Use explicit python interpreter
            uv build --wheel --python python
            if [ ! -f dist/*.whl ]; then
              echo "❌ Failed to build $pkg!"
              exit 1
            fi
            echo "✅ $pkg built successfully"
            cd ../..
          done

      - name: Install and test packages
        run: |
          # Create clean test environment
          python -m venv test_env
          if [[ "$OSTYPE" == "msys" || "$OSTYPE" == "win32" ]]; then
            source test_env/Scripts/activate
          else
            source test_env/bin/activate
          fi

          # Install built packages
          pip install packages/*/dist/*.whl

          # Basic import test
          python -c "import leann; print('✅ LEANN imported successfully')"
          python -c "import leann_backend_hnsw; print('✅ HNSW backend imported')"
          python -c "import leann_backend_diskann; print('✅ DiskANN backend imported')"

      - name: Upload artifacts
        uses: actions/upload-artifact@v4
        with:
          name: wheels-${{ matrix.os }}-py${{ matrix.python-version }}
          path: packages/*/dist/*.whl
          retention-days: 7

  # Summary job to ensure all builds pass
  ci-success:
    needs: build-test
    runs-on: ubuntu-latest
    steps:
      - name: CI Success
        run: |
          echo "✅ All CI builds passed!"
          echo "Ready for manual release when needed."
.github/workflows/link-check.yml (vendored, new file, 19 lines)
@@ -0,0 +1,19 @@
name: Link Check

on:
  push:
    branches: [ main, master ]
  pull_request:
  schedule:
    - cron: "0 3 * * 1"

jobs:
  link-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: lycheeverse/lychee-action@v2
        with:
          args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/release-manual.yml (vendored, 245 changes)
@@ -1,194 +1,129 @@
-name: Manual Release
-
-on:
-  workflow_dispatch:
-    inputs:
-      version:
-        description: 'Version to release (e.g., 0.1.1)'
-        required: true
-        type: string
-      test_pypi:
-        description: 'Test on TestPyPI first'
-        required: false
-        type: boolean
-        default: true
-
-jobs:
-  validate-and-release:
-    runs-on: ubuntu-latest
-    permissions:
-      contents: write
-      actions: read
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          token: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Check CI status
-        run: |
-          echo "ℹ️ This workflow will download build artifacts from the latest CI run."
-          echo "   CI must have completed successfully on the current commit."
-          echo ""
-
-      - name: Validate version format
-        run: |
-          if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
-            echo "❌ Invalid version format. Use semantic versioning (e.g., 0.1.1)"
-            exit 1
-          fi
-          echo "✅ Version format valid: ${{ inputs.version }}"
-
-      - name: Check if version already exists
-        run: |
-          if git tag | grep -q "^v${{ inputs.version }}$"; then
-            echo "❌ Version v${{ inputs.version }} already exists!"
-            exit 1
-          fi
-          echo "✅ Version is new"
-
-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.13'
-
-      - name: Install uv
-        run: |
-          curl -LsSf https://astral.sh/uv/install.sh | sh
-          echo "$HOME/.cargo/bin" >> $GITHUB_PATH
-
-      - name: Update versions
-        run: |
-          ./scripts/bump_version.sh ${{ inputs.version }}
-          git config user.name "GitHub Actions"
-          git config user.email "actions@github.com"
-          git add packages/*/pyproject.toml
-          git commit -m "chore: release v${{ inputs.version }}"
-
-      - name: Get CI run ID
-        id: get-ci-run
-        run: |
-          # Get the latest successful CI run on the previous commit (before version bump)
-          COMMIT_SHA=$(git rev-parse HEAD~1)
-          RUN_ID=$(gh run list \
-            --workflow="CI - Build and Test" \
-            --status=success \
-            --commit=$COMMIT_SHA \
-            --json databaseId \
-            --jq '.[0].databaseId')
-
-          if [ -z "$RUN_ID" ]; then
-            echo "❌ No successful CI run found for commit $COMMIT_SHA"
-            echo ""
-            echo "This usually means:"
-            echo "1. CI hasn't run on the latest commit yet"
-            echo "2. CI failed on the latest commit"
-            echo ""
-            echo "Please ensure CI passes on main branch before releasing."
-            exit 1
-          fi
-
-          echo "✅ Found CI run: $RUN_ID"
-          echo "run-id=$RUN_ID" >> $GITHUB_OUTPUT
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Download artifacts from CI
-        run: |
-          echo "📦 Downloading artifacts from CI run ${{ steps.get-ci-run.outputs.run-id }}..."
-
-          # Download all wheel artifacts
-          gh run download ${{ steps.get-ci-run.outputs.run-id }} \
-            --pattern "wheels-*" \
-            --dir ./dist-downloads
-
-          # Consolidate all wheels into packages/*/dist/
-          mkdir -p packages/leann-core/dist
-          mkdir -p packages/leann-backend-hnsw/dist
-          mkdir -p packages/leann-backend-diskann/dist
-          mkdir -p packages/leann/dist
-
-          find ./dist-downloads -name "*.whl" -exec cp {} ./packages/ \;
-
-          # Move wheels to correct package directories
-          for wheel in packages/*.whl; do
-            if [[ $wheel == *"leann_core"* ]]; then
-              mv "$wheel" packages/leann-core/dist/
-            elif [[ $wheel == *"leann_backend_hnsw"* ]]; then
-              mv "$wheel" packages/leann-backend-hnsw/dist/
-            elif [[ $wheel == *"leann_backend_diskann"* ]]; then
-              mv "$wheel" packages/leann-backend-diskann/dist/
-            elif [[ $wheel == *"leann-"* ]] && [[ $wheel != *"backend"* ]] && [[ $wheel != *"core"* ]]; then
-              mv "$wheel" packages/leann/dist/
-            fi
-          done
-
-          # List downloaded wheels
-          echo "✅ Downloaded wheels:"
-          find packages/*/dist -name "*.whl" -type f | sort
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Test on TestPyPI (optional)
-        if: inputs.test_pypi
-        continue-on-error: true
-        env:
-          TWINE_USERNAME: __token__
-          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
-        run: |
-          if [ -z "$TWINE_PASSWORD" ]; then
-            echo "⚠️ TEST_PYPI_API_TOKEN not configured, skipping TestPyPI upload"
-            echo "   To enable TestPyPI testing, add TEST_PYPI_API_TOKEN to repository secrets"
-            exit 0
-          fi
-
-          pip install twine
-          echo "📦 Uploading to TestPyPI..."
-          twine upload --repository testpypi packages/*/dist/* --verbose || {
-            echo "⚠️ TestPyPI upload failed, but continuing with release"
-            echo "   This is optional and won't block the release"
-            exit 0
-          }
-          echo "✅ Test upload successful!"
-          echo "📋 Check packages at: https://test.pypi.org/user/your-username/"
-          echo ""
-          echo "To test installation:"
-          echo "pip install -i https://test.pypi.org/simple/ leann"
-
-      - name: Create and push tag
-        run: |
-          git tag "v${{ inputs.version }}"
-          git push origin main
-          git push origin "v${{ inputs.version }}"
-          echo "✅ Tag v${{ inputs.version }} created and pushed"
-
-      - name: Create GitHub Release
-        uses: softprops/action-gh-release@v1
-        with:
-          tag_name: v${{ inputs.version }}
-          name: Release v${{ inputs.version }}
-          body: |
-            ## 🚀 Release v${{ inputs.version }}
-
-            ### What's Changed
-            See the [full changelog](https://github.com/${{ github.repository }}/compare/...v${{ inputs.version }})
-
-            ### Installation
-            ```bash
-            pip install leann==${{ inputs.version }}
-            ```
-
-            ### Test Installation (if using TestPyPI)
-            ```bash
-            pip install -i https://test.pypi.org/simple/ leann==${{ inputs.version }}
-            ```
-          draft: false
-          prerelease: false
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-
-      - name: Trigger PyPI publish
-        run: |
-          echo "🚀 Triggering PyPI publish workflow..."
-          # The existing build-and-publish.yml will be triggered by the tag push
-          echo "✅ Release process completed! The publish workflow will run automatically."
+name: Release
+
+on:
+  workflow_dispatch:
+    inputs:
+      version:
+        description: 'Version to release (e.g., 0.1.2)'
+        required: true
+        type: string
+
+jobs:
+  update-version:
+    name: Update Version
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    outputs:
+      commit-sha: ${{ steps.push.outputs.commit-sha }}
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Validate version
+        run: |
+          # Remove 'v' prefix if present for validation
+          VERSION_CLEAN="${{ inputs.version }}"
+          VERSION_CLEAN="${VERSION_CLEAN#v}"
+          if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
+            exit 1
+          fi
+          echo "✅ Version format valid: ${{ inputs.version }}"
+
+      - name: Update versions and push
+        id: push
+        run: |
+          # Check current version
+          CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
+          echo "Current version: $CURRENT_VERSION"
+          echo "Target version: ${{ inputs.version }}"
+
+          if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
+            echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
+            COMMIT_SHA=$(git rev-parse HEAD)
+          else
+            ./scripts/bump_version.sh ${{ inputs.version }}
+            git config user.name "GitHub Actions"
+            git config user.email "actions@github.com"
+            git add packages/*/pyproject.toml
+            git commit -m "chore: release v${{ inputs.version }}"
+            git push origin main
+            COMMIT_SHA=$(git rev-parse HEAD)
+            echo "✅ Pushed version update: $COMMIT_SHA"
+          fi
+
+          echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
+
+  build-packages:
+    name: Build packages
+    needs: update-version
+    uses: ./.github/workflows/build-reusable.yml
+    with:
+      ref: 'main'
+
+  publish:
+    name: Publish and Release
+    needs: [update-version, build-packages]
+    if: always() && needs.update-version.result == 'success' && needs.build-packages.result == 'success'
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: 'main'
+
+      - name: Download all artifacts
+        uses: actions/download-artifact@v4
+        with:
+          path: dist-artifacts
+
+      - name: Collect packages
+        run: |
+          mkdir -p dist
+          find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
+          find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
+
+          echo "📦 Packages to publish:"
+          ls -la dist/
+
+      - name: Publish to PyPI
+        env:
+          TWINE_USERNAME: __token__
+          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+        run: |
+          if [ -z "$TWINE_PASSWORD" ]; then
+            echo "❌ PYPI_API_TOKEN not configured!"
+            exit 1
+          fi
+
+          pip install twine
+          twine upload dist/* --skip-existing --verbose
+          echo "✅ Published to PyPI!"
+
+      - name: Create release
+        run: |
+          # Check if tag already exists
+          if git rev-parse "v${{ inputs.version }}" >/dev/null 2>&1; then
+            echo "⚠️ Tag v${{ inputs.version }} already exists, skipping tag creation"
+          else
+            git tag "v${{ inputs.version }}"
+            git push origin "v${{ inputs.version }}"
+            echo "✅ Created and pushed tag v${{ inputs.version }}"
+          fi
+
+          # Check if release already exists
+          if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
+            echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
+          else
+            gh release create "v${{ inputs.version }}" \
+              --title "Release v${{ inputs.version }}" \
+              --notes "🚀 Released to PyPI: https://pypi.org/project/leann/${{ inputs.version }}/" \
+              --latest
+            echo "✅ Created GitHub release v${{ inputs.version }}"
+          fi
+        env:
+          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
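The rewritten Release workflow is driven by workflow_dispatch with a single version input, so it can be started from the GitHub CLI as well as from the Actions tab. A hedged sketch (the workflow file name and the version input come from the diff above; the gh invocation itself is an assumption about how a maintainer might trigger it):

```bash
# Dispatch the manual release workflow with the target version
gh workflow run release-manual.yml -f version=0.1.2

# Follow the run that was just started
gh run watch
```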
.gitignore (vendored, 35 changes)
@@ -18,9 +18,12 @@ demo/experiment_results/**/*.json
 *.eml
 *.emlx
 *.json
+*.png
+!.vscode/*.json
 *.sh
 *.txt
 !CMakeLists.txt
+!llms.txt
 latency_breakdown*.json
 experiment_results/eval_results/diskann/*.json
 aws/
@@ -34,11 +37,15 @@ build/
 nprobe_logs/
 micro/results
 micro/contriever-INT8
-examples/data/*
-!examples/data/2501.14312v1 (1).pdf
-!examples/data/2506.08276v1.pdf
-!examples/data/PrideandPrejudice.txt
-!examples/data/README.md
+data/*
+!data/2501.14312v1 (1).pdf
+!data/2506.08276v1.pdf
+!data/PrideandPrejudice.txt
+!data/huawei_pangu.md
+!data/ground_truth/
+!data/indices/
+!data/queries/
+!data/.gitattributes
 *.qdstrm
 benchmark_results/
 results/
@@ -84,5 +91,21 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/

 *.meta.json
 *.passages.json
+*.npy
+*.db
 batchtest.py
+tests/__pytest_cache__/
+tests/__pycache__/
+benchmarks/data/
+
+## multi vector
+apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
+
+# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
+# If you need to commit a specific demo PDF, remove this negation locally.
+# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
+# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
+!apps/multimodal/vision-based-pdf-multi-vector/fig/*
+
+# AUR build directory (Arch Linux)
+paru-bin/
.gitmodules (vendored, 3 additions)
@@ -14,3 +14,6 @@
 [submodule "packages/leann-backend-hnsw/third_party/libzmq"]
 	path = packages/leann-backend-hnsw/third_party/libzmq
 	url = https://github.com/zeromq/libzmq.git
+[submodule "packages/astchunk-leann"]
+	path = packages/astchunk-leann
+	url = https://github.com/yichuan-w/astchunk-leann.git
.pre-commit-config.yaml (new file, 17 lines)
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: debug-statements

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.12.7  # Fixed version to match pyproject.toml
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
      - id: ruff-format
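With this configuration in place, the hygiene and ruff hooks run on every commit once the git hook is installed locally; a minimal sketch using standard pre-commit commands:

```bash
# Install the git hook so the checks run on each commit
pre-commit install

# Run all configured hooks once across the repository
pre-commit run --all-files
```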
.vscode/extensions.json (vendored, new file, 5 lines)
@@ -0,0 +1,5 @@
{
  "recommendations": [
    "charliermarsh.ruff",
  ]
}
.vscode/settings.json (vendored, new file, 22 lines)
@@ -0,0 +1,22 @@
{
  "python.defaultInterpreterPath": ".venv/bin/python",
  "python.terminal.activateEnvironment": true,
  "[python]": {
    "editor.defaultFormatter": "charliermarsh.ruff",
    "editor.formatOnSave": true,
    "editor.codeActionsOnSave": {
      "source.organizeImports": "explicit",
      "source.fixAll": "explicit"
    },
    "editor.insertSpaces": true,
    "editor.tabSize": 4
  },
  "ruff.enable": true,
  "files.watcherExclude": {
    "**/.venv/**": true,
    "**/__pycache__/**": true,
    "**/*.egg-info/**": true,
    "**/build/**": true,
    "**/dist/**": true
  }
}
|
}
|
||||||
407  apps/base_rag_example.py  Normal file
@@ -0,0 +1,407 @@
"""
Base class for unified RAG examples interface.
Provides common parameters and functionality for all RAG examples.
"""

import argparse
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import dotenv
from leann.api import LeannBuilder, LeannChat

# Optional import: older PyPI builds may not include interactive_utils
try:
    from leann.interactive_utils import create_rag_session
except ImportError:

    def create_rag_session(app_name: str, data_description: str):
        class _SimpleSession:
            def run_interactive_loop(self, handler):
                print(f"Interactive session for {app_name}: {data_description}")
                print("Interactive mode not available in this build")

        return _SimpleSession()


from leann.registry import register_project_directory

# Optional import: older PyPI builds may not include settings
try:
    from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
except ImportError:
    # Minimal fallbacks if settings helpers are unavailable
    import os

    def resolve_ollama_host(value: str | None) -> str | None:
        return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")

    def resolve_openai_api_key(value: str | None) -> str | None:
        return value or os.getenv("OPENAI_API_KEY")

    def resolve_openai_base_url(value: str | None) -> str | None:
        return value or os.getenv("OPENAI_BASE_URL")


dotenv.load_dotenv()


class BaseRAGExample(ABC):
    """Base class for all RAG examples with unified interface."""

    def __init__(
        self,
        name: str,
        description: str,
        default_index_name: str,
    ):
        self.name = name
        self.description = description
        self.default_index_name = default_index_name
        self.parser = self._create_parser()

    def _create_parser(self) -> argparse.ArgumentParser:
        """Create argument parser with common parameters."""
        parser = argparse.ArgumentParser(
            description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
        )

        # Core parameters (all examples share these)
        core_group = parser.add_argument_group("Core Parameters")
        core_group.add_argument(
            "--index-dir",
            type=str,
            default=f"./{self.default_index_name}",
            help=f"Directory to store the index (default: ./{self.default_index_name})",
        )
        core_group.add_argument(
            "--query",
            type=str,
            default=None,
            help="Query to run (if not provided, will run in interactive mode)",
        )
        # Allow subclasses to override default max_items
        max_items_default = getattr(self, "max_items_default", -1)
        core_group.add_argument(
            "--max-items",
            type=int,
            default=max_items_default,
            help="Maximum number of items to process (-1 indexes all documents; set a smaller limit on the first run if you have a large dataset)",
        )
        core_group.add_argument(
            "--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
        )

        # Embedding parameters
        embedding_group = parser.add_argument_group("Embedding Parameters")
        # Allow subclasses to override default embedding_model
        embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
        embedding_group.add_argument(
            "--embedding-model",
            type=str,
            default=embedding_model_default,
            help=f"Embedding model to use (default: {embedding_model_default}); e.g. facebook/contriever, text-embedding-3-small, mlx-community/Qwen3-Embedding-0.6B-8bit, or nomic-embed-text",
        )
        embedding_group.add_argument(
            "--embedding-mode",
            type=str,
            default="sentence-transformers",
            choices=["sentence-transformers", "openai", "mlx", "ollama"],
            help="Embedding backend mode: sentence-transformers, openai, mlx, or ollama (default: sentence-transformers)",
        )
        embedding_group.add_argument(
            "--embedding-host",
            type=str,
            default=None,
            help="Override Ollama-compatible embedding host",
        )
        embedding_group.add_argument(
            "--embedding-api-base",
            type=str,
            default=None,
            help="Base URL for OpenAI-compatible embedding services",
        )
        embedding_group.add_argument(
            "--embedding-api-key",
            type=str,
            default=None,
            help="API key for embedding service (defaults to OPENAI_API_KEY)",
        )

        # LLM parameters
        llm_group = parser.add_argument_group("LLM Parameters")
        llm_group.add_argument(
            "--llm",
            type=str,
            default="openai",
            choices=["openai", "ollama", "hf", "simulated"],
            help="LLM backend: openai, ollama, hf, or simulated (default: openai)",
        )
        llm_group.add_argument(
            "--llm-model",
            type=str,
            default=None,
            help="Model name (default: gpt-4o); e.g. gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
        )
        llm_group.add_argument(
            "--llm-host",
            type=str,
            default=None,
            help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
        )
        llm_group.add_argument(
            "--thinking-budget",
            type=str,
            choices=["low", "medium", "high"],
            default=None,
            help="Thinking budget for reasoning models (low/medium/high). Supported by gpt-oss:20b and other reasoning models.",
        )
        llm_group.add_argument(
            "--llm-api-base",
            type=str,
            default=None,
            help="Base URL for OpenAI-compatible APIs",
        )
        llm_group.add_argument(
            "--llm-api-key",
            type=str,
            default=None,
            help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
        )

        # AST chunking parameters
        ast_group = parser.add_argument_group("AST Chunking Parameters")
        ast_group.add_argument(
            "--use-ast-chunking",
            action="store_true",
            help="Enable AST-aware chunking for code files (requires astchunk)",
        )
        ast_group.add_argument(
            "--ast-chunk-size",
            type=int,
            default=300,
            help="Maximum characters per AST chunk (default: 300). Final chunks may be larger due to overlap; ~300 characters is recommended for 512-token models",
        )
        ast_group.add_argument(
            "--ast-chunk-overlap",
            type=int,
            default=64,
            help="Overlap between AST chunks in characters (default: 64). Added on top of the chunk size, not included in it",
        )
        ast_group.add_argument(
            "--code-file-extensions",
            nargs="+",
            default=None,
            help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
        )
        ast_group.add_argument(
            "--ast-fallback-traditional",
            action="store_true",
            default=True,
            help="Fall back to traditional chunking if AST chunking fails (default: True)",
        )

        # Search parameters
        search_group = parser.add_argument_group("Search Parameters")
        search_group.add_argument(
            "--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
        )
        search_group.add_argument(
            "--search-complexity",
            type=int,
            default=32,
            help="Search complexity for graph traversal (default: 32)",
        )

        # Index building parameters
        index_group = parser.add_argument_group("Index Building Parameters")
        index_group.add_argument(
            "--backend-name",
            type=str,
            default="hnsw",
            choices=["hnsw", "diskann"],
            help="Backend to use for index (default: hnsw)",
        )
        index_group.add_argument(
            "--graph-degree",
            type=int,
            default=32,
            help="Graph degree for index construction (default: 32)",
        )
        index_group.add_argument(
            "--build-complexity",
            type=int,
            default=64,
            help="Build complexity for index construction (default: 64)",
        )
        index_group.add_argument(
            "--no-compact",
            action="store_true",
            help="Disable compact index storage",
        )
        index_group.add_argument(
            "--no-recompute",
            action="store_true",
            help="Disable embedding recomputation",
        )

        # Add source-specific parameters
        self._add_specific_arguments(parser)

        return parser

    @abstractmethod
    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
        """Add source-specific arguments. Override in subclasses."""
        pass

    @abstractmethod
    async def load_data(self, args) -> list[str]:
        """Load data from the source. Returns list of text chunks."""
        pass

    def get_llm_config(self, args) -> dict[str, Any]:
        """Get LLM configuration based on arguments."""
        config = {"type": args.llm}

        if args.llm == "openai":
            config["model"] = args.llm_model or "gpt-4o"
            config["base_url"] = resolve_openai_base_url(args.llm_api_base)
            resolved_key = resolve_openai_api_key(args.llm_api_key)
            if resolved_key:
                config["api_key"] = resolved_key
        elif args.llm == "ollama":
            config["model"] = args.llm_model or "llama3.2:1b"
            config["host"] = resolve_ollama_host(args.llm_host)
        elif args.llm == "hf":
            config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
        elif args.llm == "simulated":
            # Simulated LLM doesn't need additional configuration
            pass

        return config

    async def build_index(self, args, texts: list[str]) -> str:
        """Build LEANN index from texts."""
        index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")

        print(f"\n[Building Index] Creating {self.name} index...")
        print(f"Total text chunks: {len(texts)}")

        embedding_options: dict[str, Any] = {}
        if args.embedding_mode == "ollama":
            embedding_options["host"] = resolve_ollama_host(args.embedding_host)
        elif args.embedding_mode == "openai":
            embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
            resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
            if resolved_embedding_key:
                embedding_options["api_key"] = resolved_embedding_key

        builder = LeannBuilder(
            backend_name=args.backend_name,
            embedding_model=args.embedding_model,
            embedding_mode=args.embedding_mode,
            embedding_options=embedding_options or None,
            graph_degree=args.graph_degree,
            complexity=args.build_complexity,
            is_compact=not args.no_compact,
            is_recompute=not args.no_recompute,
            num_threads=1,  # Force single-threaded mode
        )

        # Add texts in batches for better progress tracking
        batch_size = 1000
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            for text in batch:
                builder.add_text(text)
            print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")

        print("Building index structure...")
        builder.build_index(index_path)
        print(f"Index saved to: {index_path}")

        # Register project directory so leann list can discover this index
        # The index is saved as args.index_dir/index_name.leann
        # We want to register the current working directory where the app is run
        register_project_directory(Path.cwd())

        return index_path

    async def run_interactive_chat(self, args, index_path: str):
        """Run interactive chat with the index."""
        chat = LeannChat(
            index_path,
            llm_config=self.get_llm_config(args),
            system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
            complexity=args.search_complexity,
        )

        # Create interactive session
        session = create_rag_session(
            app_name=self.name.lower().replace(" ", "_"), data_description=self.name
        )

        def handle_query(query: str):
            # Prepare LLM kwargs with thinking budget if specified
            llm_kwargs = {}
            if hasattr(args, "thinking_budget") and args.thinking_budget:
                llm_kwargs["thinking_budget"] = args.thinking_budget

            response = chat.ask(
                query,
                top_k=args.top_k,
                complexity=args.search_complexity,
                llm_kwargs=llm_kwargs,
            )
            print(f"\nAssistant: {response}\n")

        session.run_interactive_loop(handle_query)

    async def run_single_query(self, args, index_path: str, query: str):
        """Run a single query against the index."""
        chat = LeannChat(
            index_path,
            llm_config=self.get_llm_config(args),
            complexity=args.search_complexity,
        )

        print(f"\n[Query]: \033[36m{query}\033[0m")

        # Prepare LLM kwargs with thinking budget if specified
        llm_kwargs = {}
        if hasattr(args, "thinking_budget") and args.thinking_budget:
            llm_kwargs["thinking_budget"] = args.thinking_budget

        response = chat.ask(
            query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
        )
        print(f"\n[Response]: \033[36m{response}\033[0m")

    async def run(self):
        """Main entry point for the example."""
        args = self.parser.parse_args()

        # Check if index exists
        index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
        index_exists = Path(args.index_dir).exists()

        if not index_exists or args.force_rebuild:
            # Load data and build index
            print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
            texts = await self.load_data(args)

            if not texts:
                print("No data found to index!")
                return

            index_path = await self.build_index(args, texts)
        else:
            print(f"\nUsing existing index in {args.index_dir}")

        # Run query or interactive mode
        if args.query:
            await self.run_single_query(args, index_path, args.query)
        else:
            await self.run_interactive_chat(args, index_path)
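For reference, a minimal subclass of this interface might look like the following sketch. It is illustrative only: `QuotesRAG`, the `--file` flag, and `quotes.txt` are hypothetical and not part of this changeset; the hooks it overrides (`_add_specific_arguments`, `load_data`, `run`) are the ones defined in apps/base_rag_example.py above.

import asyncio
from pathlib import Path

# Run from the apps/ directory (or add it to sys.path) so base_rag_example is importable.
from base_rag_example import BaseRAGExample


class QuotesRAG(BaseRAGExample):
    """Hypothetical example: index a plain-text file of quotes."""

    def __init__(self):
        super().__init__(
            name="Quotes",
            description="Process and query a quotes file with LEANN",
            default_index_name="quotes_index",
        )

    def _add_specific_arguments(self, parser):
        # Source-specific flag; --file is made up for this sketch
        parser.add_argument("--file", type=str, default="./quotes.txt")

    async def load_data(self, args) -> list[str]:
        # One chunk per non-empty line; the real apps use create_text_chunks()
        text = Path(args.file).read_text(encoding="utf-8")
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return lines[: args.max_items] if args.max_items > 0 else lines


if __name__ == "__main__":
    asyncio.run(QuotesRAG().run())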
171  apps/browser_rag.py  Normal file
@@ -0,0 +1,171 @@
"""
Browser History RAG example using the unified interface.
Supports Chrome browser history.
"""

import os
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks

from .history_data.history import ChromeHistoryReader


class BrowserRAG(BaseRAGExample):
    """RAG example for Chrome browser history."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="Browser History",
            description="Process and query Chrome browser history with LEANN",
            default_index_name="google_history_index",
        )

    def _add_specific_arguments(self, parser):
        """Add browser-specific arguments."""
        browser_group = parser.add_argument_group("Browser Parameters")
        browser_group.add_argument(
            "--chrome-profile",
            type=str,
            default=None,
            help="Path to Chrome profile directory (auto-detected if not specified)",
        )
        browser_group.add_argument(
            "--auto-find-profiles",
            action="store_true",
            default=True,
            help="Automatically find all Chrome profiles (default: True)",
        )
        browser_group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        browser_group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )

    def _get_chrome_base_path(self) -> Path:
        """Get the base Chrome profile path based on OS."""
        if sys.platform == "darwin":
            return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
        elif sys.platform.startswith("linux"):
            return Path.home() / ".config" / "google-chrome"
        elif sys.platform == "win32":
            return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
        else:
            raise ValueError(f"Unsupported platform: {sys.platform}")

    def _find_chrome_profiles(self) -> list[Path]:
        """Auto-detect all Chrome profiles."""
        base_path = self._get_chrome_base_path()
        if not base_path.exists():
            return []

        profiles = []

        # Check Default profile
        default_profile = base_path / "Default"
        if default_profile.exists() and (default_profile / "History").exists():
            profiles.append(default_profile)

        # Check numbered profiles
        for item in base_path.iterdir():
            if item.is_dir() and item.name.startswith("Profile "):
                if (item / "History").exists():
                    profiles.append(item)

        return profiles

    async def load_data(self, args) -> list[str]:
        """Load browser history and convert to text chunks."""
        # Determine Chrome profiles
        if args.chrome_profile and not args.auto_find_profiles:
            profile_dirs = [Path(args.chrome_profile)]
        else:
            print("Auto-detecting Chrome profiles...")
            profile_dirs = self._find_chrome_profiles()

        # If specific profile given, filter to just that one
        if args.chrome_profile:
            profile_path = Path(args.chrome_profile)
            profile_dirs = [p for p in profile_dirs if p == profile_path]

        if not profile_dirs:
            print("No Chrome profiles found!")
            print("Please specify --chrome-profile manually")
            return []

        print(f"Found {len(profile_dirs)} Chrome profiles")

        # Create reader
        reader = ChromeHistoryReader()

        # Process each profile
        all_documents = []
        total_processed = 0

        for i, profile_dir in enumerate(profile_dirs):
            print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")

            try:
                # Apply max_items limit per profile
                max_per_profile = -1
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_profile = remaining

                # Load history
                documents = reader.load_data(
                    chrome_profile_path=str(profile_dir),
                    max_count=max_per_profile,
                )

                if documents:
                    all_documents.extend(documents)
                    total_processed += len(documents)
                    print(f"Processed {len(documents)} history entries from this profile")

            except Exception as e:
                print(f"Error processing {profile_dir}: {e}")
                continue

        if not all_documents:
            print("No browser history found to process!")
            return []

        print(f"\nTotal history entries processed: {len(all_documents)}")

        # Convert to text chunks
        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        return all_texts


if __name__ == "__main__":
    import asyncio

    # Example queries for browser history RAG
    print("\n🌐 Browser History RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What websites did I visit about machine learning?'")
    print("- 'Find my search history about programming'")
    print("- 'What YouTube videos did I watch recently?'")
    print("- 'Show me websites about travel planning'")
    print("\nNote: Make sure Chrome is closed before running\n")

    rag = BrowserRAG()
    asyncio.run(rag.run())
0  apps/chatgpt_data/__init__.py  Normal file
413  apps/chatgpt_data/chatgpt_reader.py  Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
"""
|
||||||
|
ChatGPT export data reader.
|
||||||
|
|
||||||
|
Reads and processes ChatGPT export data from chat.html files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTReader(BaseReader):
|
||||||
|
"""
|
||||||
|
ChatGPT export data reader.
|
||||||
|
|
||||||
|
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
||||||
|
Processes conversations into structured documents with metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # noqa
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||||
|
|
||||||
|
self.concatenate_conversations = concatenate_conversations
|
||||||
|
|
||||||
|
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
||||||
|
"""
|
||||||
|
Extract chat.html from ChatGPT export zip file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zip_path: Path to the ChatGPT export zip file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTML content as string, or None if not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with ZipFile(zip_path, "r") as zip_file:
|
||||||
|
# Look for chat.html or conversations.html
|
||||||
|
html_files = [
|
||||||
|
f
|
||||||
|
for f in zip_file.namelist()
|
||||||
|
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
||||||
|
]
|
||||||
|
|
||||||
|
if not html_files:
|
||||||
|
print(f"No HTML chat file found in {zip_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Use the first HTML file found
|
||||||
|
html_file = html_files[0]
|
||||||
|
print(f"Found HTML file: {html_file}")
|
||||||
|
|
||||||
|
with zip_file.open(html_file) as f:
|
||||||
|
return f.read().decode("utf-8", errors="ignore")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Parse ChatGPT HTML export to extract conversations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
html_content: HTML content from ChatGPT export
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of conversation dictionaries
|
||||||
|
"""
|
||||||
|
soup = BeautifulSoup(html_content, "html.parser")
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
# Try different possible structures for ChatGPT exports
|
||||||
|
# Structure 1: Look for conversation containers
|
||||||
|
conversation_containers = soup.find_all(
|
||||||
|
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not conversation_containers:
|
||||||
|
# Structure 2: Look for message containers directly
|
||||||
|
conversation_containers = [soup] # Use the entire document as one conversation
|
||||||
|
|
||||||
|
for container in conversation_containers:
|
||||||
|
conversation = self._extract_conversation_from_container(container)
|
||||||
|
if conversation and conversation.get("messages"):
|
||||||
|
conversations.append(conversation)
|
||||||
|
|
||||||
|
# If no structured conversations found, try to extract all text as one conversation
|
||||||
|
if not conversations:
|
||||||
|
all_text = soup.get_text(separator="\n", strip=True)
|
||||||
|
if all_text:
|
||||||
|
conversations.append(
|
||||||
|
{
|
||||||
|
"title": "ChatGPT Conversation",
|
||||||
|
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
||||||
|
"timestamp": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
def _extract_conversation_from_container(self, container) -> dict | None:
|
||||||
|
"""
|
||||||
|
Extract conversation data from a container element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
container: BeautifulSoup element containing conversation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with conversation data or None
|
||||||
|
"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# Look for message elements with various possible structures
|
||||||
|
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
||||||
|
|
||||||
|
for selector in message_selectors:
|
||||||
|
message_elements = container.select(selector)
|
||||||
|
if message_elements:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
message_elements = []
|
||||||
|
|
||||||
|
# If no structured messages found, treat the entire container as one message
|
||||||
|
if not message_elements:
|
||||||
|
text_content = container.get_text(separator="\n", strip=True)
|
||||||
|
if text_content:
|
||||||
|
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
||||||
|
else:
|
||||||
|
for element in message_elements:
|
||||||
|
message = self._extract_message_from_element(element)
|
||||||
|
if message:
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to extract conversation title
|
||||||
|
title_element = container.find(["h1", "h2", "h3", "title"])
|
||||||
|
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
||||||
|
|
||||||
|
# Try to extract timestamp from various possible locations
|
||||||
|
timestamp = self._extract_timestamp_from_container(container)
|
||||||
|
|
||||||
|
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||||
|
|
||||||
|
def _extract_message_from_element(self, element) -> dict | None:
|
||||||
|
"""
|
||||||
|
Extract message data from an element.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
element: BeautifulSoup element containing message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with message data or None
|
||||||
|
"""
|
||||||
|
text_content = element.get_text(separator=" ", strip=True)
|
||||||
|
|
||||||
|
# Skip empty or very short messages
|
||||||
|
if not text_content or len(text_content.strip()) < 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to determine role (user/assistant) from class names or content
|
||||||
|
role = "mixed" # Default role
|
||||||
|
|
||||||
|
class_names = " ".join(element.get("class", [])).lower()
|
||||||
|
if "user" in class_names or "human" in class_names:
|
||||||
|
role = "user"
|
||||||
|
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
||||||
|
role = "assistant"
|
||||||
|
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
||||||
|
role = "user"
|
||||||
|
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
||||||
|
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
||||||
|
role = "assistant"
|
||||||
|
text_content = re.sub(
|
||||||
|
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to extract timestamp
|
||||||
|
timestamp = self._extract_timestamp_from_element(element)
|
||||||
|
|
||||||
|
return {"role": role, "content": text_content, "timestamp": timestamp}
|
||||||
|
|
||||||
|
def _extract_timestamp_from_element(self, element) -> str | None:
|
||||||
|
"""Extract timestamp from element."""
|
||||||
|
# Look for timestamp in various attributes and child elements
|
||||||
|
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
||||||
|
for attr in timestamp_attrs:
|
||||||
|
if element.get(attr):
|
||||||
|
return element.get(attr)
|
||||||
|
|
||||||
|
# Look for time elements
|
||||||
|
time_element = element.find("time")
|
||||||
|
if time_element:
|
||||||
|
return time_element.get("datetime") or time_element.get_text(strip=True)
|
||||||
|
|
||||||
|
# Look for date-like text patterns
|
||||||
|
text = element.get_text()
|
||||||
|
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
||||||
|
|
||||||
|
for pattern in date_patterns:
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_timestamp_from_container(self, container) -> str | None:
|
||||||
|
"""Extract timestamp from conversation container."""
|
||||||
|
return self._extract_timestamp_from_element(container)
|
||||||
|
|
||||||
|
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||||
|
"""
|
||||||
|
Create concatenated content from conversation messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation: Dictionary containing conversation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted concatenated content
|
||||||
|
"""
|
||||||
|
title = conversation.get("title", "ChatGPT Conversation")
|
||||||
|
messages = conversation.get("messages", [])
|
||||||
|
timestamp = conversation.get("timestamp", "Unknown")
|
||||||
|
|
||||||
|
# Build message content
|
||||||
|
message_parts = []
|
||||||
|
for message in messages:
|
||||||
|
role = message.get("role", "mixed")
|
||||||
|
content = message.get("content", "")
|
||||||
|
msg_timestamp = message.get("timestamp", "")
|
||||||
|
|
||||||
|
if role == "user":
|
||||||
|
prefix = "[You]"
|
||||||
|
elif role == "assistant":
|
||||||
|
prefix = "[ChatGPT]"
|
||||||
|
else:
|
||||||
|
prefix = "[Message]"
|
||||||
|
|
||||||
|
# Add timestamp if available
|
||||||
|
if msg_timestamp:
|
||||||
|
prefix += f" ({msg_timestamp})"
|
||||||
|
|
||||||
|
message_parts.append(f"{prefix}: {content}")
|
||||||
|
|
||||||
|
concatenated_text = "\n\n".join(message_parts)
|
||||||
|
|
||||||
|
# Create final document content
|
||||||
|
doc_content = f"""Conversation: {title}
|
||||||
|
Date: {timestamp}
|
||||||
|
Messages ({len(messages)} messages):
|
||||||
|
|
||||||
|
{concatenated_text}
|
||||||
|
"""
|
||||||
|
return doc_content
|
||||||
|
|
||||||
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
|
"""
|
||||||
|
Load ChatGPT export data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dir: Directory containing ChatGPT export files or path to specific file
|
||||||
|
**load_kwargs:
|
||||||
|
max_count (int): Maximum number of conversations to process
|
||||||
|
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
||||||
|
include_metadata (bool): Whether to include metadata in documents
|
||||||
|
"""
|
||||||
|
docs: list[Document] = []
|
||||||
|
max_count = load_kwargs.get("max_count", -1)
|
||||||
|
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
||||||
|
include_metadata = load_kwargs.get("include_metadata", True)
|
||||||
|
|
||||||
|
if not chatgpt_export_path:
|
||||||
|
print("No ChatGPT export path provided")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
export_path = Path(chatgpt_export_path)
|
||||||
|
|
||||||
|
if not export_path.exists():
|
||||||
|
print(f"ChatGPT export path not found: {export_path}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
html_content = None
|
||||||
|
|
||||||
|
# Handle different input types
|
||||||
|
if export_path.is_file():
|
||||||
|
if export_path.suffix.lower() == ".zip":
|
||||||
|
# Extract HTML from zip file
|
||||||
|
html_content = self._extract_html_from_zip(export_path)
|
||||||
|
elif export_path.suffix.lower() == ".html":
|
||||||
|
# Read HTML file directly
|
||||||
|
try:
|
||||||
|
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||||
|
html_content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading HTML file {export_path}: {e}")
|
||||||
|
return docs
|
||||||
|
else:
|
||||||
|
print(f"Unsupported file type: {export_path.suffix}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
elif export_path.is_dir():
|
||||||
|
# Look for HTML files in directory
|
||||||
|
html_files = list(export_path.glob("*.html"))
|
||||||
|
zip_files = list(export_path.glob("*.zip"))
|
||||||
|
|
||||||
|
if html_files:
|
||||||
|
# Use first HTML file found
|
||||||
|
html_file = html_files[0]
|
||||||
|
print(f"Found HTML file: {html_file}")
|
||||||
|
try:
|
||||||
|
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
||||||
|
html_content = f.read()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading HTML file {html_file}: {e}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
elif zip_files:
|
||||||
|
# Use first zip file found
|
||||||
|
zip_file = zip_files[0]
|
||||||
|
print(f"Found zip file: {zip_file}")
|
||||||
|
html_content = self._extract_html_from_zip(zip_file)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"No HTML or zip files found in {export_path}")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
if not html_content:
|
||||||
|
print("No HTML content found to process")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
# Parse conversations from HTML
|
||||||
|
print("Parsing ChatGPT conversations from HTML...")
|
||||||
|
conversations = self._parse_chatgpt_html(html_content)
|
||||||
|
|
||||||
|
if not conversations:
|
||||||
|
print("No conversations found in HTML content")
|
||||||
|
return docs
|
||||||
|
|
||||||
|
print(f"Found {len(conversations)} conversations")
|
||||||
|
|
||||||
|
# Process conversations into documents
|
||||||
|
count = 0
|
||||||
|
for conversation in conversations:
|
||||||
|
if max_count > 0 and count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.concatenate_conversations:
|
||||||
|
# Create one document per conversation with concatenated messages
|
||||||
|
doc_content = self._create_concatenated_content(conversation)
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
if include_metadata:
|
||||||
|
metadata = {
|
||||||
|
"title": conversation.get("title", "ChatGPT Conversation"),
|
||||||
|
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||||
|
"message_count": len(conversation.get("messages", [])),
|
||||||
|
"source": "ChatGPT Export",
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Create separate documents for each message
|
||||||
|
for message in conversation.get("messages", []):
|
||||||
|
if max_count > 0 and count >= max_count:
|
||||||
|
break
|
||||||
|
|
||||||
|
role = message.get("role", "mixed")
|
||||||
|
content = message.get("content", "")
|
||||||
|
msg_timestamp = message.get("timestamp", "")
|
||||||
|
|
||||||
|
if not content.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create document content with context
|
||||||
|
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
||||||
|
Role: {role}
|
||||||
|
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||||
|
Message: {content}
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
if include_metadata:
|
||||||
|
metadata = {
|
||||||
|
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
||||||
|
"role": role,
|
||||||
|
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||||
|
"source": "ChatGPT Export",
|
||||||
|
}
|
||||||
|
|
||||||
|
doc = Document(text=doc_content, metadata=metadata)
|
||||||
|
docs.append(doc)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
print(f"Created {len(docs)} documents from ChatGPT export")
|
||||||
|
return docs
|
||||||
186  apps/chatgpt_rag.py  Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
ChatGPT RAG example using the unified interface.
|
||||||
|
Supports ChatGPT export data from chat.html files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add parent directory to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from base_rag_example import BaseRAGExample
|
||||||
|
from chunking import create_text_chunks
|
||||||
|
|
||||||
|
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
||||||
|
|
||||||
|
|
||||||
|
class ChatGPTRAG(BaseRAGExample):
|
||||||
|
"""RAG example for ChatGPT conversation data."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Set default values BEFORE calling super().__init__
|
||||||
|
self.max_items_default = -1 # Process all conversations by default
|
||||||
|
self.embedding_model_default = (
|
||||||
|
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
name="ChatGPT",
|
||||||
|
description="Process and query ChatGPT conversation exports with LEANN",
|
||||||
|
default_index_name="chatgpt_conversations_index",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser):
|
||||||
|
"""Add ChatGPT-specific arguments."""
|
||||||
|
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
||||||
|
chatgpt_group.add_argument(
|
||||||
|
"--export-path",
|
||||||
|
type=str,
|
||||||
|
default="./chatgpt_export",
|
||||||
|
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
||||||
|
)
|
||||||
|
chatgpt_group.add_argument(
|
||||||
|
"--concatenate-conversations",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Concatenate messages within conversations for better context (default: True)",
|
||||||
|
)
|
||||||
|
chatgpt_group.add_argument(
|
||||||
|
"--separate-messages",
|
||||||
|
action="store_true",
|
||||||
|
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||||
|
)
|
||||||
|
chatgpt_group.add_argument(
|
||||||
|
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||||
|
)
|
||||||
|
chatgpt_group.add_argument(
|
||||||
|
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Find ChatGPT export files in the given path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
export_path: Path to search for exports
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of paths to ChatGPT export files
|
||||||
|
"""
|
||||||
|
export_files = []
|
||||||
|
|
||||||
|
if export_path.is_file():
|
||||||
|
if export_path.suffix.lower() in [".zip", ".html"]:
|
||||||
|
export_files.append(export_path)
|
||||||
|
elif export_path.is_dir():
|
||||||
|
# Look for zip and html files
|
||||||
|
export_files.extend(export_path.glob("*.zip"))
|
||||||
|
export_files.extend(export_path.glob("*.html"))
|
||||||
|
|
||||||
|
return export_files
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load ChatGPT export data and convert to text chunks."""
|
||||||
|
export_path = Path(args.export_path)
|
||||||
|
|
||||||
|
if not export_path.exists():
|
||||||
|
print(f"ChatGPT export path not found: {export_path}")
|
||||||
|
print(
|
||||||
|
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
||||||
|
)
|
||||||
|
print("\nTo export your ChatGPT data:")
|
||||||
|
print("1. Sign in to ChatGPT")
|
||||||
|
print("2. Click on your profile icon → Settings → Data Controls")
|
||||||
|
print("3. Click 'Export' under Export Data")
|
||||||
|
print("4. Download the zip file from the email link")
|
||||||
|
print("5. Extract or place the file/directory at the specified path")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Find export files
|
||||||
|
export_files = self._find_chatgpt_exports(export_path)
|
||||||
|
|
||||||
|
if not export_files:
|
||||||
|
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(export_files)} ChatGPT export files")
|
||||||
|
|
||||||
|
# Create reader with appropriate settings
|
||||||
|
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||||
|
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
||||||
|
|
||||||
|
# Process each export file
|
||||||
|
all_documents = []
|
||||||
|
total_processed = 0
|
||||||
|
|
||||||
|
for i, export_file in enumerate(export_files):
|
||||||
|
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Apply max_items limit per file
|
||||||
|
max_per_file = -1
|
||||||
|
if args.max_items > 0:
|
||||||
|
remaining = args.max_items - total_processed
|
||||||
|
if remaining <= 0:
|
||||||
|
break
|
||||||
|
max_per_file = remaining
|
||||||
|
|
||||||
|
# Load conversations
|
||||||
|
documents = reader.load_data(
|
||||||
|
chatgpt_export_path=str(export_file),
|
||||||
|
max_count=max_per_file,
|
||||||
|
include_metadata=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if documents:
|
||||||
|
all_documents.extend(documents)
|
||||||
|
total_processed += len(documents)
|
||||||
|
print(f"Processed {len(documents)} conversations from this file")
|
||||||
|
else:
|
||||||
|
print(f"No conversations loaded from {export_file}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {export_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not all_documents:
|
||||||
|
print("No conversations found to process!")
|
||||||
|
print("\nTroubleshooting:")
|
||||||
|
print("- Ensure the export file is a valid ChatGPT export")
|
||||||
|
print("- Check that the HTML file contains conversation data")
|
||||||
|
print("- Try extracting the zip file and pointing to the HTML file directly")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||||
|
print("Now starting to split into text chunks... this may take some time")
|
||||||
|
|
||||||
|
# Convert to text chunks
|
||||||
|
all_texts = create_text_chunks(
|
||||||
|
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Example queries for ChatGPT RAG
|
||||||
|
print("\n🤖 ChatGPT RAG Example")
|
||||||
|
print("=" * 50)
|
||||||
|
print("\nExample queries you can try:")
|
||||||
|
print("- 'What did I ask about Python programming?'")
|
||||||
|
print("- 'Show me conversations about machine learning'")
|
||||||
|
print("- 'Find discussions about travel planning'")
|
||||||
|
print("- 'What advice did ChatGPT give me about career development?'")
|
||||||
|
print("- 'Search for conversations about cooking recipes'")
|
||||||
|
print("\nTo get started:")
|
||||||
|
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
||||||
|
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
||||||
|
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
||||||
|
print("\nOr run without --query for interactive mode\n")
|
||||||
|
|
||||||
|
rag = ChatGPTRAG()
|
||||||
|
asyncio.run(rag.run())
|
||||||
47  apps/chunking/__init__.py  Normal file
@@ -0,0 +1,47 @@
"""Unified chunking utilities facade.

This module re-exports the packaged utilities from `leann.chunking_utils` so
that both repo apps (importing `chunking`) and installed wheels share one
single implementation. When running from the repo without installation, it
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
"""

import sys
from pathlib import Path

try:
    from leann.chunking_utils import (
        CODE_EXTENSIONS,
        _traditional_chunks_as_dicts,
        create_ast_chunks,
        create_text_chunks,
        create_traditional_chunks,
        detect_code_files,
        get_language_from_extension,
    )
except Exception:  # pragma: no cover - best-effort fallback for dev environment
    repo_root = Path(__file__).resolve().parents[2]
    leann_src = repo_root / "packages" / "leann-core" / "src"
    if leann_src.exists():
        sys.path.insert(0, str(leann_src))
        from leann.chunking_utils import (
            CODE_EXTENSIONS,
            _traditional_chunks_as_dicts,
            create_ast_chunks,
            create_text_chunks,
            create_traditional_chunks,
            detect_code_files,
            get_language_from_extension,
        )
    else:
        raise

__all__ = [
    "CODE_EXTENSIONS",
    "_traditional_chunks_as_dicts",
    "create_ast_chunks",
    "create_text_chunks",
    "create_traditional_chunks",
    "detect_code_files",
    "get_language_from_extension",
]
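For reference, the facade is consumed the same way the apps above use it. A minimal sketch follows; the sample documents and chunk sizes are illustrative, and the helper is assumed to behave as the call sites above suggest (a list of documents in, a list of text chunks out).

from llama_index.core import Document

from chunking import create_text_chunks  # re-exported from leann.chunking_utils

docs = [
    Document(text="LEANN builds a compact vector index over personal data.", metadata={"source": "notes"}),
    Document(text="AST chunking keeps code constructs intact while splitting.", metadata={"source": "docs"}),
]

# Same call shape as browser_rag.py / chatgpt_rag.py above
chunks = create_text_chunks(docs, chunk_size=256, chunk_overlap=128)
print(f"{len(chunks)} chunks")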
0  apps/claude_data/__init__.py  Normal file
420  apps/claude_data/claude_reader.py  Normal file
@@ -0,0 +1,420 @@
|
|||||||
|
"""
|
||||||
|
Claude export data reader.
|
||||||
|
|
||||||
|
Reads and processes Claude conversation data from exported JSON files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from zipfile import ZipFile
|
||||||
|
|
||||||
|
from llama_index.core import Document
|
||||||
|
from llama_index.core.readers.base import BaseReader
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeReader(BaseReader):
|
||||||
|
"""
|
||||||
|
Claude export data reader.
|
||||||
|
|
||||||
|
Reads Claude conversation data from exported JSON files or zip archives.
|
||||||
|
Processes conversations into structured documents with metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Initialize.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||||
|
"""
|
||||||
|
self.concatenate_conversations = concatenate_conversations
|
||||||
|
|
||||||
|
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
||||||
|
"""
|
||||||
|
Extract JSON files from Claude export zip file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
zip_path: Path to the Claude export zip file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of JSON content strings, or empty list if not found
|
||||||
|
"""
|
||||||
|
json_contents = []
|
||||||
|
try:
|
||||||
|
with ZipFile(zip_path, "r") as zip_file:
|
||||||
|
# Look for JSON files
|
||||||
|
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||||
|
|
||||||
|
if not json_files:
|
||||||
|
print(f"No JSON files found in {zip_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Found {len(json_files)} JSON files in archive")
|
||||||
|
|
||||||
|
for json_file in json_files:
|
||||||
|
with zip_file.open(json_file) as f:
|
||||||
|
content = f.read().decode("utf-8", errors="ignore")
|
||||||
|
json_contents.append(content)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
||||||
|
|
||||||
|
return json_contents
|
||||||
|
|
||||||
|
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Parse Claude JSON export to extract conversations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
json_content: JSON content from Claude export
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of conversation dictionaries
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(json_content)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Error parsing JSON: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
|
||||||
|
# Handle different possible JSON structures
|
||||||
|
if isinstance(data, list):
|
||||||
|
# If data is a list of conversations
|
||||||
|
for item in data:
|
||||||
|
conversation = self._extract_conversation_from_json(item)
|
||||||
|
if conversation:
|
||||||
|
conversations.append(conversation)
|
||||||
|
elif isinstance(data, dict):
|
||||||
|
# Check for common structures
|
||||||
|
if "conversations" in data:
|
||||||
|
# Structure: {"conversations": [...]}
|
||||||
|
for item in data["conversations"]:
|
||||||
|
conversation = self._extract_conversation_from_json(item)
|
||||||
|
if conversation:
|
||||||
|
conversations.append(conversation)
|
||||||
|
elif "messages" in data:
|
||||||
|
# Single conversation with messages
|
||||||
|
conversation = self._extract_conversation_from_json(data)
|
||||||
|
if conversation:
|
||||||
|
conversations.append(conversation)
|
||||||
|
else:
|
||||||
|
# Try to treat the whole object as a conversation
|
||||||
|
conversation = self._extract_conversation_from_json(data)
|
||||||
|
if conversation:
|
||||||
|
conversations.append(conversation)
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
||||||
|
"""
|
||||||
|
Extract conversation data from a JSON object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conv_data: Dictionary containing conversation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with conversation data or None
|
||||||
|
"""
|
||||||
|
if not isinstance(conv_data, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# Look for messages in various possible structures
|
||||||
|
message_sources = []
|
||||||
|
if "messages" in conv_data:
|
||||||
|
message_sources = conv_data["messages"]
|
||||||
|
elif "chat" in conv_data:
|
||||||
|
message_sources = conv_data["chat"]
|
||||||
|
elif "conversation" in conv_data:
|
||||||
|
message_sources = conv_data["conversation"]
|
||||||
|
else:
|
||||||
|
# If no clear message structure, try to extract from the object itself
|
||||||
|
if "content" in conv_data and "role" in conv_data:
|
||||||
|
message_sources = [conv_data]
|
||||||
|
|
||||||
|
for msg_data in message_sources:
|
||||||
|
message = self._extract_message_from_json(msg_data)
|
||||||
|
if message:
|
||||||
|
messages.append(message)
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
return None

        # Extract conversation metadata
        title = self._extract_title_from_conversation(conv_data, messages)
        timestamp = self._extract_timestamp_from_conversation(conv_data)

        return {"title": title, "messages": messages, "timestamp": timestamp}

    def _extract_message_from_json(self, msg_data: dict) -> dict | None:
        """
        Extract message data from a JSON message object.

        Args:
            msg_data: Dictionary containing message data

        Returns:
            Dictionary with message data or None
        """
        if not isinstance(msg_data, dict):
            return None

        # Extract content from various possible fields
        content = ""
        content_fields = ["content", "text", "message", "body"]
        for field in content_fields:
            if msg_data.get(field):
                content = str(msg_data[field])
                break

        if not content or len(content.strip()) < 3:
            return None

        # Extract role (user/assistant/human/ai/claude)
        role = "mixed"  # Default role
        role_fields = ["role", "sender", "from", "author", "type"]
        for field in role_fields:
            if msg_data.get(field):
                role_value = str(msg_data[field]).lower()
                if role_value in ["user", "human", "person"]:
                    role = "user"
                elif role_value in ["assistant", "ai", "claude", "bot"]:
                    role = "assistant"
                break

        # Extract timestamp
        timestamp = self._extract_timestamp_from_message(msg_data)

        return {"role": role, "content": content, "timestamp": timestamp}

    def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
        """Extract timestamp from message data."""
        timestamp_fields = ["timestamp", "created_at", "date", "time"]
        for field in timestamp_fields:
            if msg_data.get(field):
                return str(msg_data[field])
        return None

    def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
        """Extract timestamp from conversation data."""
        timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
        for field in timestamp_fields:
            if conv_data.get(field):
                return str(conv_data[field])
        return None

    def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
        """Extract or generate title for conversation."""
        # Try to find explicit title
        title_fields = ["title", "name", "subject", "topic"]
        for field in title_fields:
            if conv_data.get(field):
                return str(conv_data[field])

        # Generate title from first user message
        for message in messages:
            if message.get("role") == "user":
                content = message.get("content", "")
                if content:
                    # Use first 50 characters as title
                    title = content[:50].strip()
                    if len(content) > 50:
                        title += "..."
                    return title

        return "Claude Conversation"

    def _create_concatenated_content(self, conversation: dict) -> str:
        """
        Create concatenated content from conversation messages.

        Args:
            conversation: Dictionary containing conversation data

        Returns:
            Formatted concatenated content
        """
        title = conversation.get("title", "Claude Conversation")
        messages = conversation.get("messages", [])
        timestamp = conversation.get("timestamp", "Unknown")

        # Build message content
        message_parts = []
        for message in messages:
            role = message.get("role", "mixed")
            content = message.get("content", "")
            msg_timestamp = message.get("timestamp", "")

            if role == "user":
                prefix = "[You]"
            elif role == "assistant":
                prefix = "[Claude]"
            else:
                prefix = "[Message]"

            # Add timestamp if available
            if msg_timestamp:
                prefix += f" ({msg_timestamp})"

            message_parts.append(f"{prefix}: {content}")

        concatenated_text = "\n\n".join(message_parts)

        # Create final document content
        doc_content = f"""Conversation: {title}
Date: {timestamp}
Messages ({len(messages)} messages):

{concatenated_text}
"""
        return doc_content

    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load Claude export data.

        Args:
            input_dir: Directory containing Claude export files or path to specific file
            **load_kwargs:
                max_count (int): Maximum number of conversations to process
                claude_export_path (str): Specific path to Claude export file/directory
                include_metadata (bool): Whether to include metadata in documents
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", -1)
        claude_export_path = load_kwargs.get("claude_export_path", input_dir)
        include_metadata = load_kwargs.get("include_metadata", True)

        if not claude_export_path:
            print("No Claude export path provided")
            return docs

        export_path = Path(claude_export_path)

        if not export_path.exists():
            print(f"Claude export path not found: {export_path}")
            return docs

        json_contents = []

        # Handle different input types
        if export_path.is_file():
            if export_path.suffix.lower() == ".zip":
                # Extract JSON from zip file
                json_contents = self._extract_json_from_zip(export_path)
            elif export_path.suffix.lower() == ".json":
                # Read JSON file directly
                try:
                    with open(export_path, encoding="utf-8", errors="ignore") as f:
                        json_contents.append(f.read())
                except Exception as e:
                    print(f"Error reading JSON file {export_path}: {e}")
                    return docs
            else:
                print(f"Unsupported file type: {export_path.suffix}")
                return docs

        elif export_path.is_dir():
            # Look for JSON files in directory
            json_files = list(export_path.glob("*.json"))
            zip_files = list(export_path.glob("*.zip"))

            if json_files:
                print(f"Found {len(json_files)} JSON files in directory")
                for json_file in json_files:
                    try:
                        with open(json_file, encoding="utf-8", errors="ignore") as f:
                            json_contents.append(f.read())
                    except Exception as e:
                        print(f"Error reading JSON file {json_file}: {e}")
                        continue

            if zip_files:
                print(f"Found {len(zip_files)} ZIP files in directory")
                for zip_file in zip_files:
                    zip_contents = self._extract_json_from_zip(zip_file)
                    json_contents.extend(zip_contents)

            if not json_files and not zip_files:
                print(f"No JSON or ZIP files found in {export_path}")
                return docs

        if not json_contents:
            print("No JSON content found to process")
            return docs

        # Parse conversations from JSON content
        print("Parsing Claude conversations from JSON...")
        all_conversations = []
        for json_content in json_contents:
            conversations = self._parse_claude_json(json_content)
            all_conversations.extend(conversations)

        if not all_conversations:
            print("No conversations found in JSON content")
            return docs

        print(f"Found {len(all_conversations)} conversations")

        # Process conversations into documents
        count = 0
        for conversation in all_conversations:
            if max_count > 0 and count >= max_count:
                break

            if self.concatenate_conversations:
                # Create one document per conversation with concatenated messages
                doc_content = self._create_concatenated_content(conversation)

                metadata = {}
                if include_metadata:
                    metadata = {
                        "title": conversation.get("title", "Claude Conversation"),
                        "timestamp": conversation.get("timestamp", "Unknown"),
                        "message_count": len(conversation.get("messages", [])),
                        "source": "Claude Export",
                    }

                doc = Document(text=doc_content, metadata=metadata)
                docs.append(doc)
                count += 1

            else:
                # Create separate documents for each message
                for message in conversation.get("messages", []):
                    if max_count > 0 and count >= max_count:
                        break

                    role = message.get("role", "mixed")
                    content = message.get("content", "")
                    msg_timestamp = message.get("timestamp", "")

                    if not content.strip():
                        continue

                    # Create document content with context
                    doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
Role: {role}
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
Message: {content}
"""

                    metadata = {}
                    if include_metadata:
                        metadata = {
                            "conversation_title": conversation.get("title", "Claude Conversation"),
                            "role": role,
                            "timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
                            "source": "Claude Export",
                        }

                    doc = Document(text=doc_content, metadata=metadata)
                    docs.append(doc)
                    count += 1

        print(f"Created {len(docs)} documents from Claude export")
        return docs
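As a quick usage sketch (the export filename here is hypothetical; the module path matches the import used in apps/claude_rag.py below), the reader above can be driven directly:

    from apps.claude_data.claude_reader import ClaudeReader

    # Load up to 10 conversations from a Claude export file
    reader = ClaudeReader(concatenate_conversations=True)
    docs = reader.load_data(claude_export_path="./claude_export/conversations.json", max_count=10)
    print(f"Loaded {len(docs)} documents")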
189  apps/claude_rag.py  Normal file
@@ -0,0 +1,189 @@
"""
Claude RAG example using the unified interface.
Supports Claude export data from JSON files.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks

from .claude_data.claude_reader import ClaudeReader


class ClaudeRAG(BaseRAGExample):
    """RAG example for Claude conversation data."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.max_items_default = -1  # Process all conversations by default
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="Claude",
            description="Process and query Claude conversation exports with LEANN",
            default_index_name="claude_conversations_index",
        )

    def _add_specific_arguments(self, parser):
        """Add Claude-specific arguments."""
        claude_group = parser.add_argument_group("Claude Parameters")
        claude_group.add_argument(
            "--export-path",
            type=str,
            default="./claude_export",
            help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
        )
        claude_group.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Concatenate messages within conversations for better context (default: True)",
        )
        claude_group.add_argument(
            "--separate-messages",
            action="store_true",
            help="Process each message as a separate document (overrides --concatenate-conversations)",
        )
        claude_group.add_argument(
            "--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
        )
        claude_group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )

    def _find_claude_exports(self, export_path: Path) -> list[Path]:
        """
        Find Claude export files in the given path.

        Args:
            export_path: Path to search for exports

        Returns:
            List of paths to Claude export files
        """
        export_files = []

        if export_path.is_file():
            if export_path.suffix.lower() in [".zip", ".json"]:
                export_files.append(export_path)
        elif export_path.is_dir():
            # Look for zip and json files
            export_files.extend(export_path.glob("*.zip"))
            export_files.extend(export_path.glob("*.json"))

        return export_files

    async def load_data(self, args) -> list[str]:
        """Load Claude export data and convert to text chunks."""
        export_path = Path(args.export_path)

        if not export_path.exists():
            print(f"Claude export path not found: {export_path}")
            print(
                "Please ensure you have exported your Claude data and placed it in the correct location."
            )
            print("\nTo export your Claude data:")
            print("1. Open Claude in your browser")
            print("2. Look for export/download options in settings or conversation menu")
            print("3. Download the conversation data (usually in JSON format)")
            print("4. Place the file/directory at the specified path")
            print(
                "\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
            )
            return []

        # Find export files
        export_files = self._find_claude_exports(export_path)

        if not export_files:
            print(f"No Claude export files (.json or .zip) found in: {export_path}")
            return []

        print(f"Found {len(export_files)} Claude export files")

        # Create reader with appropriate settings
        concatenate = args.concatenate_conversations and not args.separate_messages
        reader = ClaudeReader(concatenate_conversations=concatenate)

        # Process each export file
        all_documents = []
        total_processed = 0

        for i, export_file in enumerate(export_files):
            print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")

            try:
                # Apply max_items limit per file
                max_per_file = -1
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_file = remaining

                # Load conversations
                documents = reader.load_data(
                    claude_export_path=str(export_file),
                    max_count=max_per_file,
                    include_metadata=True,
                )

                if documents:
                    all_documents.extend(documents)
                    total_processed += len(documents)
                    print(f"Processed {len(documents)} conversations from this file")
                else:
                    print(f"No conversations loaded from {export_file}")

            except Exception as e:
                print(f"Error processing {export_file}: {e}")
                continue

        if not all_documents:
            print("No conversations found to process!")
            print("\nTroubleshooting:")
            print("- Ensure the export file is a valid Claude export")
            print("- Check that the JSON file contains conversation data")
            print("- Try using a different export format or method")
            print("- Check Claude's documentation for current export procedures")
            return []

        print(f"\nTotal conversations processed: {len(all_documents)}")
        print("Now starting to split into text chunks... this may take some time")

        # Convert to text chunks
        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
        return all_texts


if __name__ == "__main__":
    import asyncio

    # Example queries for Claude RAG
    print("\n🤖 Claude RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What did I ask Claude about Python programming?'")
    print("- 'Show me conversations about machine learning'")
    print("- 'Find discussions about code optimization'")
    print("- 'What advice did Claude give me about software design?'")
    print("- 'Search for conversations about debugging techniques'")
    print("\nTo get started:")
    print("1. Export your Claude conversation data")
    print("2. Place the JSON/ZIP file in ./claude_export/")
    print("3. Run this script to build your personal Claude knowledge base!")
    print("\nOr run without --query for interactive mode\n")

    rag = ClaudeRAG()
    asyncio.run(rag.run())
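A typical invocation of this example, following the `python -m apps.<name>` pattern used by the other examples in this PR (the query text is illustrative):

    python -m apps.claude_rag --export-path ./claude_export --query "What did I ask Claude about Python programming?"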
211  apps/code_rag.py  Normal file
@@ -0,0 +1,211 @@
"""
Code RAG example using AST-aware chunking for optimal code understanding.
Specialized for code repositories with automatic language detection and
optimized chunking parameters.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import CODE_EXTENSIONS, create_text_chunks
from llama_index.core import SimpleDirectoryReader


class CodeRAG(BaseRAGExample):
    """Specialized RAG example for code repositories with AST-aware chunking."""

    def __init__(self):
        super().__init__(
            name="Code",
            description="Process and query code repositories with AST-aware chunking",
            default_index_name="code_index",
        )
        # Override defaults for code-specific usage
        self.embedding_model_default = "facebook/contriever"  # Good for code
        self.max_items_default = -1  # Process all code files by default

    def _add_specific_arguments(self, parser):
        """Add code-specific arguments."""
        code_group = parser.add_argument_group("Code Repository Parameters")

        code_group.add_argument(
            "--repo-dir",
            type=str,
            default=".",
            help="Code repository directory to index (default: current directory)",
        )
        code_group.add_argument(
            "--include-extensions",
            nargs="+",
            default=list(CODE_EXTENSIONS.keys()),
            help="File extensions to include (default: supported code extensions)",
        )
        code_group.add_argument(
            "--exclude-dirs",
            nargs="+",
            default=[
                ".git",
                "__pycache__",
                "node_modules",
                "venv",
                ".venv",
                "build",
                "dist",
                "target",
            ],
            help="Directories to exclude from indexing",
        )
        code_group.add_argument(
            "--max-file-size",
            type=int,
            default=1000000,  # 1MB
            help="Maximum file size in bytes to process (default: 1MB)",
        )
        code_group.add_argument(
            "--include-comments",
            action="store_true",
            help="Include comments in chunking (useful for documentation)",
        )
        code_group.add_argument(
            "--preserve-imports",
            action="store_true",
            default=True,
            help="Try to preserve import statements in chunks (default: True)",
        )

    async def load_data(self, args) -> list[str]:
        """Load code files and convert to AST-aware chunks."""
        print(f"🔍 Scanning code repository: {args.repo_dir}")
        print(f"📁 Including extensions: {args.include_extensions}")
        print(f"🚫 Excluding directories: {args.exclude_dirs}")

        # Check if repository directory exists
        repo_path = Path(args.repo_dir)
        if not repo_path.exists():
            raise ValueError(f"Repository directory not found: {args.repo_dir}")

        # Load code files with filtering
        reader_kwargs = {
            "recursive": True,
            "encoding": "utf-8",
            "required_exts": args.include_extensions,
            "exclude_hidden": True,
        }

        # Create exclusion filter
        def file_filter(file_path: str) -> bool:
            """Filter out unwanted files and directories."""
            path = Path(file_path)

            # Check file size
            try:
                if path.stat().st_size > args.max_file_size:
                    print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
                    return False
            except Exception:
                return False

            # Check if in excluded directory
            for exclude_dir in args.exclude_dirs:
                if exclude_dir in path.parts:
                    return False

            return True

        try:
            # Load documents with file filtering
            documents = SimpleDirectoryReader(
                args.repo_dir,
                file_extractor=None,  # Use default extractors
                **reader_kwargs,
            ).load_data(show_progress=True)

            # Apply custom filtering
            filtered_docs = []
            for doc in documents:
                file_path = doc.metadata.get("file_path", "")
                if file_filter(file_path):
                    filtered_docs.append(doc)

            documents = filtered_docs

        except Exception as e:
            print(f"❌ Error loading code files: {e}")
            return []

        if not documents:
            print(
                f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
            )
            return []

        print(f"✅ Loaded {len(documents)} code files")

        # Show breakdown by language/extension
        ext_counts = {}
        for doc in documents:
            file_path = doc.metadata.get("file_path", "")
            if file_path:
                ext = Path(file_path).suffix.lower()
                ext_counts[ext] = ext_counts.get(ext, 0) + 1

        print("📊 Files by extension:")
        for ext, count in sorted(ext_counts.items()):
            print(f" {ext}: {count} files")

        # Use AST-aware chunking by default for code
        print(
            f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
        )

        all_texts = create_text_chunks(
            documents,
            chunk_size=256,  # Fallback for non-code files
            chunk_overlap=64,
            use_ast_chunking=True,  # Always use AST for code RAG
            ast_chunk_size=args.ast_chunk_size,
            ast_chunk_overlap=args.ast_chunk_overlap,
            code_file_extensions=args.include_extensions,
            ast_fallback_traditional=True,
        )

        # Apply max_items limit if specified
        if args.max_items > 0 and len(all_texts) > args.max_items:
            print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
            all_texts = all_texts[: args.max_items]

        print(f"✅ Generated {len(all_texts)} code chunks")
        return all_texts


if __name__ == "__main__":
    import asyncio

    # Example queries for code RAG
    print("\n💻 Code RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'How does the embedding computation work?'")
    print("- 'What are the main classes in this codebase?'")
    print("- 'Show me the search implementation'")
    print("- 'How is error handling implemented?'")
    print("- 'What design patterns are used?'")
    print("- 'Explain the chunking logic'")
    print("\n🚀 Features:")
    print("- ✅ AST-aware chunking preserves code structure")
    print("- ✅ Automatic language detection")
    print("- ✅ Smart filtering of large files and common excludes")
    print("- ✅ Optimized for code understanding")
    print("\nUsage examples:")
    print(" python -m apps.code_rag --repo-dir ./my_project")
    print(
        " python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
    )
    print("\nOr run without --query for interactive mode\n")

    rag = CodeRAG()
    asyncio.run(rag.run())
131  apps/document_rag.py  Normal file
@@ -0,0 +1,131 @@
"""
Document RAG example using the unified interface.
Supports PDF, TXT, MD, and other document formats.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks
from llama_index.core import SimpleDirectoryReader


class DocumentRAG(BaseRAGExample):
    """RAG example for document processing (PDF, TXT, MD, etc.)."""

    def __init__(self):
        super().__init__(
            name="Document",
            description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
            default_index_name="test_doc_files",
        )

    def _add_specific_arguments(self, parser):
        """Add document-specific arguments."""
        doc_group = parser.add_argument_group("Document Parameters")
        doc_group.add_argument(
            "--data-dir",
            type=str,
            default="data",
            help="Directory containing documents to index (default: data)",
        )
        doc_group.add_argument(
            "--file-types",
            nargs="+",
            default=None,
            help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
        )
        doc_group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        doc_group.add_argument(
            "--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
        )
        doc_group.add_argument(
            "--enable-code-chunking",
            action="store_true",
            help="Enable AST-aware chunking for code files in the data directory",
        )

    async def load_data(self, args) -> list[str]:
        """Load documents and convert to text chunks."""
        print(f"Loading documents from: {args.data_dir}")
        if args.file_types:
            print(f"Filtering by file types: {args.file_types}")
        else:
            print("Processing all supported file types")

        # Check if data directory exists
        data_path = Path(args.data_dir)
        if not data_path.exists():
            raise ValueError(f"Data directory not found: {args.data_dir}")

        # Load documents
        reader_kwargs = {
            "recursive": True,
            "encoding": "utf-8",
        }
        if args.file_types:
            reader_kwargs["required_exts"] = args.file_types

        documents = SimpleDirectoryReader(args.data_dir, **reader_kwargs).load_data(
            show_progress=True
        )

        if not documents:
            print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
            return []

        print(f"Loaded {len(documents)} documents")

        # Determine chunking strategy
        use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)

        if use_ast:
            print("Using AST-aware chunking for code files")

        # Convert to text chunks with optional AST support
        all_texts = create_text_chunks(
            documents,
            chunk_size=args.chunk_size,
            chunk_overlap=args.chunk_overlap,
            use_ast_chunking=use_ast,
            ast_chunk_size=getattr(args, "ast_chunk_size", 512),
            ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
            code_file_extensions=getattr(args, "code_file_extensions", None),
            ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
        )

        # Apply max_items limit if specified
        if args.max_items > 0 and len(all_texts) > args.max_items:
            print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
            all_texts = all_texts[: args.max_items]

        return all_texts


if __name__ == "__main__":
    import asyncio

    # Example queries for document RAG
    print("\n📄 Document RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What are the main techniques LEANN uses?'")
    print("- 'What is the technique DLPM?'")
    print("- 'Who does Elizabeth Bennet marry?'")
    print(
        "- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
    )
    print("\n🚀 NEW: Code-aware chunking available!")
    print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
    print("- Supports Python, Java, C#, TypeScript files")
    print("- Better semantic understanding of code structure")
    print("\nOr run without --query for interactive mode\n")

    rag = DocumentRAG()
    asyncio.run(rag.run())
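Likewise, an illustrative invocation for the document example (the data directory and query are placeholders; the flags are those defined above):

    python -m apps.document_rag --data-dir ./data --file-types .pdf .md --query "What are the main techniques LEANN uses?"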
167  apps/email_data/LEANN_email_reader.py  Normal file
@@ -0,0 +1,167 @@
import email
import os
from pathlib import Path
from typing import Any

from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


def find_all_messages_directories(root: str | None = None) -> list[Path]:
    """
    Recursively find all 'Messages' directories under the given root.
    Returns a list of Path objects.
    """
    if root is None:
        # Auto-detect user's mail path
        home_dir = os.path.expanduser("~")
        root = os.path.join(home_dir, "Library", "Mail")

    messages_dirs = []
    for dirpath, _dirnames, _filenames in os.walk(root):
        if os.path.basename(dirpath) == "Messages":
            messages_dirs.append(Path(dirpath))
    return messages_dirs


class EmlxReader(BaseReader):
    """
    Apple Mail .emlx file reader with embedded metadata.

    Reads individual .emlx files from Apple Mail's storage format.
    """

    def __init__(self, include_html: bool = False) -> None:
        """
        Initialize.

        Args:
            include_html: Whether to include HTML content in the email body (default: False)
        """
        self.include_html = include_html

    def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
        """
        Load data from the input directory containing .emlx files.

        Args:
            input_dir: Directory containing .emlx files
            **load_kwargs:
                max_count (int): Maximum amount of messages to read.
        """
        docs: list[Document] = []
        max_count = load_kwargs.get("max_count", 1000)
        count = 0
        total_files = 0
        successful_files = 0
        failed_files = 0

        print(f"Starting to process directory: {input_dir}")

        # Walk through the directory recursively
        for dirpath, dirnames, filenames in os.walk(input_dir):
            # Skip hidden directories
            dirnames[:] = [d for d in dirnames if not d.startswith(".")]

            for filename in filenames:
                # Check if we've reached the max count (skip if max_count == -1)
                if max_count > 0 and count >= max_count:
                    break

                if filename.endswith(".emlx"):
                    total_files += 1
                    filepath = os.path.join(dirpath, filename)
                    try:
                        # Read the .emlx file
                        with open(filepath, encoding="utf-8", errors="ignore") as f:
                            content = f.read()

                        # .emlx files have a length prefix followed by the email content
                        # The first line contains the length, followed by the email
                        lines = content.split("\n", 1)
                        if len(lines) >= 2:
                            email_content = lines[1]

                            # Parse the email using Python's email module
                            try:
                                msg = email.message_from_string(email_content)

                                # Extract email metadata
                                subject = msg.get("Subject", "No Subject")
                                from_addr = msg.get("From", "Unknown")
                                to_addr = msg.get("To", "Unknown")
                                date = msg.get("Date", "Unknown")

                                # Extract email body
                                body = ""
                                if msg.is_multipart():
                                    for part in msg.walk():
                                        if (
                                            part.get_content_type() == "text/plain"
                                            or part.get_content_type() == "text/html"
                                        ):
                                            if (
                                                part.get_content_type() == "text/html"
                                                and not self.include_html
                                            ):
                                                continue
                                            try:
                                                payload = part.get_payload(decode=True)
                                                if payload:
                                                    body += payload.decode("utf-8", errors="ignore")
                                            except Exception as e:
                                                print(f"Error decoding payload: {e}")
                                                continue
                                else:
                                    try:
                                        payload = msg.get_payload(decode=True)
                                        if payload:
                                            body = payload.decode("utf-8", errors="ignore")
                                    except Exception as e:
                                        print(f"Error decoding single part payload: {e}")
                                        body = ""

                                # Only create document if we have some content
                                if body.strip() or subject != "No Subject":
                                    # Create document content with metadata embedded in text
                                    doc_content = f"""
[File]: {filename}
[From]: {from_addr}
[To]: {to_addr}
[Subject]: {subject}
[Date]: {date}
[EMAIL BODY Start]:
{body}
"""

                                    # No separate metadata - everything is in the text
                                    doc = Document(text=doc_content, metadata={})
                                    docs.append(doc)
                                    count += 1
                                    successful_files += 1

                                    # Print first few successful files for debugging
                                    if successful_files <= 3:
                                        print(
                                            f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
                                        )

                            except Exception as e:
                                failed_files += 1
                                if failed_files <= 5:  # Only print first few errors
                                    print(f"Error parsing email from {filepath}: {e}")
                                continue

                    except Exception as e:
                        failed_files += 1
                        if failed_files <= 5:  # Only print first few errors
                            print(f"Error reading file {filepath}: {e}")
                        continue

        print("Processing summary:")
        print(f" Total .emlx files found: {total_files}")
        print(f" Successfully loaded: {successful_files}")
        print(f" Failed to load: {failed_files}")
        print(f" Final documents: {len(docs)}")

        return docs
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.

 import logging
 from pathlib import Path
-from typing import Any, Dict, List, Optional
-from fsspec import AbstractFileSystem
+from typing import Any

+from fsspec import AbstractFileSystem
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.schema import Document

@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
     """

     DEFAULT_MESSAGE_FORMAT: str = (
-        "Date: {_date}\n"
-        "From: {_from}\n"
-        "To: {_to}\n"
-        "Subject: {_subject}\n"
-        "Content: {_content}"
+        "Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
     )

     def __init__(
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
         try:
             from bs4 import BeautifulSoup  # noqa
         except ImportError:
-            raise ImportError(
-                "`beautifulsoup4` package not found: `pip install beautifulsoup4`"
-            )
+            raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")

         super().__init__(*args, **kwargs)
         self.max_count = max_count
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
     def load_data(
         self,
         file: Path,
-        extra_info: Optional[Dict] = None,
-        fs: Optional[AbstractFileSystem] = None,
-    ) -> List[Document]:
+        extra_info: dict | None = None,
+        fs: AbstractFileSystem | None = None,
+    ) -> list[Document]:
         """Parse file into string."""
         # Import required libraries
         import mailbox
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
         )

         i = 0
-        results: List[str] = []
+        results: list[str] = []
         # Load file using mailbox
         bytes_parser = BytesParser(policy=default).parse
         mbox = mailbox.mbox(file, factory=bytes_parser)  # type: ignore
@@ -134,12 +128,12 @@ class EmlxMboxReader(MboxReader):
     def load_data(
         self,
         directory: Path,
-        extra_info: Optional[Dict] = None,
-        fs: Optional[AbstractFileSystem] = None,
-    ) -> List[Document]:
+        extra_info: dict | None = None,
+        fs: AbstractFileSystem | None = None,
+    ) -> list[Document]:
         """Parse .emlx files from directory into strings using MboxReader logic."""
-        import tempfile
         import os
+        import tempfile

         if fs:
             logger.warning(
@@ -156,18 +150,18 @@ class EmlxMboxReader(MboxReader):
             return []

         # Create a temporary mbox file
-        with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
             temp_mbox_path = temp_mbox.name

         # Convert .emlx files to mbox format
         for emlx_file in emlx_files:
             try:
                 # Read the .emlx file
-                with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
+                with open(emlx_file, encoding="utf-8", errors="ignore") as f:
                     content = f.read()

                 # .emlx format: first line is length, rest is email content
-                lines = content.split('\n', 1)
+                lines = content.split("\n", 1)
                 if len(lines) >= 2:
                     email_content = lines[1]  # Skip the length line

@@ -188,5 +182,5 @@ class EmlxMboxReader(MboxReader):
         # Clean up temporary file
         try:
             os.unlink(temp_mbox_path)
-        except:
+        except OSError:
             pass
157  apps/email_rag.py  Normal file
@@ -0,0 +1,157 @@
"""
Email RAG example using the unified interface.
Supports Apple Mail on macOS.
"""

import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample
from chunking import create_text_chunks

from .email_data.LEANN_email_reader import EmlxReader


class EmailRAG(BaseRAGExample):
    """RAG example for Apple Mail processing."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.max_items_default = -1  # Process all emails by default
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="Email",
            description="Process and query Apple Mail emails with LEANN",
            default_index_name="mail_index",
        )

    def _add_specific_arguments(self, parser):
        """Add email-specific arguments."""
        email_group = parser.add_argument_group("Email Parameters")
        email_group.add_argument(
            "--mail-path",
            type=str,
            default=None,
            help="Path to Apple Mail directory (auto-detected if not specified)",
        )
        email_group.add_argument(
            "--include-html", action="store_true", help="Include HTML content in email processing"
        )
        email_group.add_argument(
            "--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
        )
        email_group.add_argument(
            "--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
        )

    def _find_mail_directories(self) -> list[Path]:
        """Auto-detect all Apple Mail directories."""
        mail_base = Path.home() / "Library" / "Mail"
        if not mail_base.exists():
            return []

        # Find all Messages directories
        messages_dirs = []
        for item in mail_base.rglob("Messages"):
            if item.is_dir():
                messages_dirs.append(item)

        return messages_dirs

    async def load_data(self, args) -> list[str]:
        """Load emails and convert to text chunks."""
        # Determine mail directories
        if args.mail_path:
            messages_dirs = [Path(args.mail_path)]
        else:
            print("Auto-detecting Apple Mail directories...")
            messages_dirs = self._find_mail_directories()

        if not messages_dirs:
            print("No Apple Mail directories found!")
            print("Please specify --mail-path manually")
            return []

        print(f"Found {len(messages_dirs)} mail directories")

        # Create reader
        reader = EmlxReader(include_html=args.include_html)

        # Process each directory
        all_documents = []
        total_processed = 0

        for i, messages_dir in enumerate(messages_dirs):
            print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")

            try:
                # Count emlx files
                emlx_files = list(messages_dir.glob("*.emlx"))
                print(f"Found {len(emlx_files)} email files")

                # Apply max_items limit per directory
                max_per_dir = -1  # Default to process all
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_dir = remaining
                # If args.max_items == -1, max_per_dir stays -1 (process all)

                # Load emails - fix the parameter passing
                documents = reader.load_data(
                    input_dir=str(messages_dir),
                    max_count=max_per_dir,
                )

                if documents:
                    all_documents.extend(documents)
                    total_processed += len(documents)
                    print(f"Processed {len(documents)} emails from this directory")

            except Exception as e:
                print(f"Error processing {messages_dir}: {e}")
                continue

        if not all_documents:
            print("No emails found to process!")
            return []

        print(f"\nTotal emails processed: {len(all_documents)}")
        print("Now starting to split into text chunks... this may take some time")

        # Convert to text chunks
        # Email reader uses chunk_overlap=25 as in original
        all_texts = create_text_chunks(
            all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
        )

        return all_texts


if __name__ == "__main__":
    import asyncio

    # Check platform
    if sys.platform != "darwin":
        print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
        print(" Windows/Linux support coming soon!\n")

    # Example queries for email RAG
    print("\n📧 Email RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'What did my boss say about deadlines?'")
    print("- 'Find emails about travel expenses'")
    print("- 'Show me emails from last month about the project'")
    print("- 'What food did I order from DoorDash?'")
    print("\nNote: You may need to grant Full Disk Access to your terminal\n")

    rag = EmailRAG()
    asyncio.run(rag.run())
@@ -1,3 +1,3 @@
 from .history import ChromeHistoryReader

-__all__ = ['ChromeHistoryReader']
+__all__ = ["ChromeHistoryReader"]
@@ -1,10 +1,12 @@
-import sqlite3
 import os
+import sqlite3
 from pathlib import Path
-from typing import List, Any
+from typing import Any
+
 from llama_index.core import Document
 from llama_index.core.readers.base import BaseReader


 class ChromeHistoryReader(BaseReader):
     """
     Chrome browser history reader that extracts browsing data from SQLite database.
@@ -17,7 +19,7 @@ class ChromeHistoryReader(BaseReader):
         """Initialize."""
         pass

-    def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
+    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
         """
         Load Chrome history data from the default Chrome profile location.

@@ -27,13 +29,15 @@ class ChromeHistoryReader(BaseReader):
                 max_count (int): Maximum amount of history entries to read.
                 chrome_profile_path (str): Custom path to Chrome profile directory.
         """
-        docs: List[Document] = []
-        max_count = load_kwargs.get('max_count', 1000)
-        chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
+        docs: list[Document] = []
+        max_count = load_kwargs.get("max_count", 1000)
+        chrome_profile_path = load_kwargs.get("chrome_profile_path", None)

         # Default Chrome profile path on macOS
         if chrome_profile_path is None:
-            chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
+            chrome_profile_path = os.path.expanduser(
+                "~/Library/Application Support/Google/Chrome/Default"
+            )

         history_db_path = os.path.join(chrome_profile_path, "History")

@@ -70,7 +74,7 @@ class ChromeHistoryReader(BaseReader):
                 if count >= max_count and max_count > 0:
                     break

-                last_visit, url, title, visit_count, typed_count, hidden = row
+                last_visit, url, title, visit_count, typed_count, _hidden = row

                 # Create document content with metadata embedded in text
                 doc_content = f"""
@@ -82,7 +86,7 @@ class ChromeHistoryReader(BaseReader):
                 """

                 # Create document with embedded metadata
-                doc = Document(text=doc_content, metadata={ "title": title[0:150]})
+                doc = Document(text=doc_content, metadata={"title": title[0:150]})
                 # if len(title) > 150:
                 #     print(f"Title is too long: {title}")
                 docs.append(doc)
@@ -93,12 +97,17 @@ class ChromeHistoryReader(BaseReader):

         except Exception as e:
             print(f"Error reading Chrome history: {e}")
+            # add you may need to close your browser to make the database file available
+            # also highlight in red
+            print(
+                "\033[91mYou may need to close your browser to make the database file available\033[0m"
+            )
             return docs

         return docs

     @staticmethod
-    def find_chrome_profiles() -> List[Path]:
+    def find_chrome_profiles() -> list[Path]:
         """
         Find all Chrome profile directories.

@@ -124,7 +133,9 @@ class ChromeHistoryReader(BaseReader):
         return profile_dirs

     @staticmethod
-    def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
+    def export_history_to_file(
+        output_file: str = "chrome_history_export.txt", max_count: int = 1000
+    ):
         """
         Export Chrome history to a text file using the same SQL query format.

@@ -132,7 +143,9 @@ class ChromeHistoryReader(BaseReader):
             output_file: Path to the output file
             max_count: Maximum number of entries to export
         """
-        chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
+        chrome_profile_path = os.path.expanduser(
+            "~/Library/Application Support/Google/Chrome/Default"
+        )
         history_db_path = os.path.join(chrome_profile_path, "History")

         if not os.path.exists(history_db_path):
@@ -159,10 +172,12 @@ class ChromeHistoryReader(BaseReader):
         cursor.execute(query, (max_count,))
         rows = cursor.fetchall()

-        with open(output_file, 'w', encoding='utf-8') as f:
+        with open(output_file, "w", encoding="utf-8") as f:
             for row in rows:
                 last_visit, url, title, visit_count, typed_count, hidden = row
-                f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
+                f.write(
+                    f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
+                )

         conn.close()
         print(f"Exported {len(rows)} history entries to {output_file}")
@@ -2,13 +2,14 @@ import json
 import os
 import re
 import subprocess
-import sys
 import time
+from datetime import datetime
 from pathlib import Path
-from typing import List, Any, Dict, Optional
+from typing import Any

 from llama_index.core import Document
 from llama_index.core.readers.base import BaseReader
-from datetime import datetime


 class WeChatHistoryReader(BaseReader):
     """
@@ -43,10 +44,16 @@ class WeChatHistoryReader(BaseReader):
         wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
         if not wechattweak_path.exists():
             print("Downloading WeChatTweak CLI...")
-            subprocess.run([
-                "curl", "-L", "-o", str(wechattweak_path),
-                "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
-            ], check=True)
+            subprocess.run(
+                [
+                    "curl",
+                    "-L",
+                    "-o",
+                    str(wechattweak_path),
+                    "https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
+                ],
+                check=True,
+            )

             # Make executable
             wechattweak_path.chmod(0o755)
@@ -73,16 +80,16 @@ class WeChatHistoryReader(BaseReader):
     def check_api_available(self) -> bool:
         """Check if WeChatTweak API is available."""
         try:
-            result = subprocess.run([
-                "curl", "-s", "http://localhost:48065/wechat/allcontacts"
-            ], capture_output=True, text=True, timeout=5)
+            result = subprocess.run(
+                ["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
             return result.returncode == 0 and result.stdout.strip()
         except Exception:
             return False

     def _extract_readable_text(self, content: str) -> str:
         """
         Extract readable text from message content, removing XML and system messages.
@@ -100,14 +107,14 @@ class WeChatHistoryReader(BaseReader):
         if isinstance(content, dict):
             # Extract text from dictionary structure
             text_parts = []
-            if 'title' in content:
-                text_parts.append(str(content['title']))
-            if 'quoted' in content:
-                text_parts.append(str(content['quoted']))
-            if 'content' in content:
-                text_parts.append(str(content['content']))
-            if 'text' in content:
-                text_parts.append(str(content['text']))
+            if "title" in content:
+                text_parts.append(str(content["title"]))
+            if "quoted" in content:
+                text_parts.append(str(content["quoted"]))
+            if "content" in content:
+                text_parts.append(str(content["content"]))
+            if "text" in content:
+                text_parts.append(str(content["text"]))

             if text_parts:
                 return " | ".join(text_parts)
@@ -120,11 +127,11 @@ class WeChatHistoryReader(BaseReader):
             return ""

         # Remove common prefixes like "wxid_xxx:\n"
-        clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
-        clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
+        clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
+        clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)

         # If it's just XML or system message, return empty
-        if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
+        if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
             return ""

         return clean_content.strip()
@@ -145,9 +152,9 @@ class WeChatHistoryReader(BaseReader):
         # Handle dictionary content
         if isinstance(content, dict):
             # Check if dict has any readable text fields
-            text_fields = ['title', 'quoted', 'content', 'text']
+            text_fields = ["title", "quoted", "content", "text"]
             for field in text_fields:
-                if field in content and content[field]:
+                if content.get(field):
                     return True
             return False

@@ -156,42 +163,47 @@ class WeChatHistoryReader(BaseReader):
             return False

         # Skip image messages (contain XML with img tags)
-        if '<img' in content and 'cdnurl' in content:
+        if "<img" in content and "cdnurl" in content:
             return False

         # Skip emoji messages (contain emoji XML tags)
-        if '<emoji' in content and 'productid' in content:
+        if "<emoji" in content and "productid" in content:
             return False

         # Skip voice messages
-        if '<voice' in content:
+        if "<voice" in content:
             return False

         # Skip video messages
-        if '<video' in content:
+        if "<video" in content:
             return False

         # Skip file messages
-        if '<appmsg' in content and 'appid' in content:
+        if "<appmsg" in content and "appid" in content:
             return False

         # Skip system messages (like "recalled a message")
-        if 'recalled a message' in content:
+        if "recalled a message" in content:
             return False

         # Check if there's actual readable text (not just XML or system messages)
         # Remove common prefixes like "wxid_xxx:\n" and check for actual content
-        clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
-        clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
+        clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
+        clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)

         # If after cleaning we have meaningful text, consider it readable
-        if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
+        if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
             return True

         return False

-    def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
-                              time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
+    def _concatenate_messages(
+        self,
+        messages: list[dict],
+        max_length: int = 128,
+        time_window_minutes: int = 30,
+        overlap_messages: int = 0,
+    ) -> list[dict]:
         """
|
"""
|
||||||
Concatenate messages based on length and time rules.
|
Concatenate messages based on length and time rules.
|
||||||
|
|
||||||
@@ -214,12 +226,12 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Extract message info
|
# Extract message info
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -236,16 +248,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if time_diff_minutes > time_window_minutes:
|
if time_diff_minutes > time_window_minutes:
|
||||||
# Time gap too large, start new group
|
# Time gap too large, start new group
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -254,16 +274,24 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
message_length = len(readable_text)
|
message_length = len(readable_text)
|
||||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||||
# Current group would exceed max length, save it and start new
|
# Current group would exceed max length, save it and start new
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
# Keep last few messages for overlap
|
# Keep last few messages for overlap
|
||||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||||
current_group = current_group[-overlap_messages:]
|
current_group = current_group[-overlap_messages:]
|
||||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
current_length = sum(
|
||||||
|
len(
|
||||||
|
self._extract_readable_text(msg.get("content", ""))
|
||||||
|
or msg.get("message", "")
|
||||||
|
)
|
||||||
|
for msg in current_group
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
current_group = []
|
current_group = []
|
||||||
current_length = 0
|
current_length = 0
|
||||||
@@ -275,16 +303,18 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
|
|
||||||
# Add the last group if it exists
|
# Add the last group if it exists
|
||||||
if current_group:
|
if current_group:
|
||||||
concatenated_groups.append({
|
concatenated_groups.append(
|
||||||
'messages': current_group,
|
{
|
||||||
'total_length': current_length,
|
"messages": current_group,
|
||||||
'start_time': current_group[0].get('createTime', 0),
|
"total_length": current_length,
|
||||||
'end_time': current_group[-1].get('createTime', 0)
|
"start_time": current_group[0].get("createTime", 0),
|
||||||
})
|
"end_time": current_group[-1].get("createTime", 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return concatenated_groups
|
return concatenated_groups
|
||||||
|
|
||||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
def _create_concatenated_content(self, message_group: dict, contact_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
Create concatenated content from a group of messages.
|
Create concatenated content from a group of messages.
|
||||||
|
|
||||||
@@ -295,16 +325,16 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted concatenated content
|
Formatted concatenated content
|
||||||
"""
|
"""
|
||||||
messages = message_group['messages']
|
messages = message_group["messages"]
|
||||||
start_time = message_group['start_time']
|
start_time = message_group["start_time"]
|
||||||
end_time = message_group['end_time']
|
end_time = message_group["end_time"]
|
||||||
|
|
||||||
# Format timestamps
|
# Format timestamps
|
||||||
if start_time:
|
if start_time:
|
||||||
try:
|
try:
|
||||||
start_timestamp = datetime.fromtimestamp(start_time)
|
start_timestamp = datetime.fromtimestamp(start_time)
|
||||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
start_time_str = str(start_time)
|
start_time_str = str(start_time)
|
||||||
else:
|
else:
|
||||||
start_time_str = "Unknown"
|
start_time_str = "Unknown"
|
||||||
@@ -312,8 +342,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
if end_time:
|
if end_time:
|
||||||
try:
|
try:
|
||||||
end_timestamp = datetime.fromtimestamp(end_time)
|
end_timestamp = datetime.fromtimestamp(end_time)
|
||||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
end_time_str = str(end_time)
|
end_time_str = str(end_time)
|
||||||
else:
|
else:
|
||||||
end_time_str = "Unknown"
|
end_time_str = "Unknown"
|
||||||
@@ -321,10 +351,10 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
# Build concatenated message content
|
# Build concatenated message content
|
||||||
message_parts = []
|
message_parts = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Extract readable text
|
# Extract readable text
|
||||||
readable_text = self._extract_readable_text(content)
|
readable_text = self._extract_readable_text(content)
|
||||||
@@ -336,8 +366,8 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
# change to YYYY-MM-DD HH:MM:SS
|
# change to YYYY-MM-DD HH:MM:SS
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -351,7 +381,7 @@ class WeChatHistoryReader(BaseReader):
|
|||||||
doc_content = f"""
|
doc_content = f"""
|
||||||
Contact: {contact_name}
|
Contact: {contact_name}
|
||||||
Time Range: {start_time_str} - {end_time_str}
|
Time Range: {start_time_str} - {end_time_str}
|
||||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||||
|
|
||||||
{concatenated_text}
|
{concatenated_text}
|
||||||
"""
|
"""
|
||||||
@@ -361,7 +391,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
"""
|
"""
|
||||||
return doc_content, contact_name
|
return doc_content, contact_name
|
||||||
|
|
||||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||||
"""
|
"""
|
||||||
Load WeChat chat history data from exported JSON files.
|
Load WeChat chat history data from exported JSON files.
|
||||||
|
|
||||||
@@ -376,13 +406,13 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||||
"""
|
"""
|
||||||
docs: List[Document] = []
|
docs: list[Document] = []
|
||||||
max_count = load_kwargs.get('max_count', 1000)
|
max_count = load_kwargs.get("max_count", 1000)
|
||||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||||
include_non_text = load_kwargs.get('include_non_text', False)
|
include_non_text = load_kwargs.get("include_non_text", False)
|
||||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||||
max_length = load_kwargs.get('max_length', 1000)
|
max_length = load_kwargs.get("max_length", 1000)
|
||||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||||
|
|
||||||
# Default WeChat export path
|
# Default WeChat export path
|
||||||
if wechat_export_dir is None:
|
if wechat_export_dir is None:
|
||||||
@@ -403,7 +433,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as f:
|
with open(json_file, encoding="utf-8") as f:
|
||||||
chat_data = json.load(f)
|
chat_data = json.load(f)
|
||||||
|
|
||||||
# Extract contact name from filename
|
# Extract contact name from filename
|
||||||
@@ -414,7 +444,7 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
readable_messages = []
|
readable_messages = []
|
||||||
for message in chat_data:
|
for message in chat_data:
|
||||||
try:
|
try:
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
if not include_non_text and not self._is_text_message(content):
|
if not include_non_text and not self._is_text_message(content):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -430,9 +460,9 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
# Concatenate messages based on rules
|
# Concatenate messages based on rules
|
||||||
message_groups = self._concatenate_messages(
|
message_groups = self._concatenate_messages(
|
||||||
readable_messages,
|
readable_messages,
|
||||||
max_length=-1,
|
max_length=max_length,
|
||||||
time_window_minutes=-1,
|
time_window_minutes=time_window_minutes,
|
||||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
overlap_messages=0, # No overlap between groups
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create documents from concatenated groups
|
# Create documents from concatenated groups
|
||||||
@@ -440,12 +470,19 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
doc_content, contact_name = self._create_concatenated_content(
|
||||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
message_group, contact_name
|
||||||
|
)
|
||||||
|
doc = Document(
|
||||||
|
text=doc_content,
|
||||||
|
metadata={"contact_name": contact_name},
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
print(
|
||||||
|
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Original single-message processing
|
# Original single-message processing
|
||||||
@@ -454,12 +491,12 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Extract message information
|
# Extract message information
|
||||||
from_user = message.get('fromUser', '')
|
message.get("fromUser", "")
|
||||||
to_user = message.get('toUser', '')
|
message.get("toUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||||
|
|
||||||
# Handle content that might be dict or string
|
# Handle content that might be dict or string
|
||||||
try:
|
try:
|
||||||
@@ -480,8 +517,8 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -495,7 +532,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Create document with embedded metadata
|
# Create document with embedded metadata
|
||||||
doc = Document(text=doc_content, metadata={})
|
doc = Document(
|
||||||
|
text=doc_content, metadata={"contact_name": contact_name}
|
||||||
|
)
|
||||||
docs.append(doc)
|
docs.append(doc)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
@@ -512,7 +551,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
return docs
|
return docs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_wechat_export_dirs() -> List[Path]:
|
def find_wechat_export_dirs() -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find all WeChat export directories.
|
Find all WeChat export directories.
|
||||||
|
|
||||||
@@ -523,10 +562,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
|
|
||||||
# Look for common export directory names
|
# Look for common export directory names
|
||||||
possible_dirs = [
|
possible_dirs = [
|
||||||
Path("./wechat_export_test"),
|
|
||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir in possible_dirs:
|
for export_dir in possible_dirs:
|
||||||
@@ -534,13 +573,20 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
json_files = list(export_dir.glob("*.json"))
|
json_files = list(export_dir.glob("*.json"))
|
||||||
if json_files:
|
if json_files:
|
||||||
export_dirs.append(export_dir)
|
export_dirs.append(export_dir)
|
||||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
print(
|
||||||
|
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||||
return export_dirs
|
return export_dirs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
def export_chat_to_file(
|
||||||
|
output_file: str = "wechat_chat_export.txt",
|
||||||
|
max_count: int = 1000,
|
||||||
|
export_dir: str | None = None,
|
||||||
|
include_non_text: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history to a text file.
|
Export WeChat chat history to a text file.
|
||||||
|
|
||||||
@@ -560,14 +606,14 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
try:
|
try:
|
||||||
json_files = list(Path(export_dir).glob("*.json"))
|
json_files = list(Path(export_dir).glob("*.json"))
|
||||||
|
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
count = 0
|
count = 0
|
||||||
for json_file in json_files:
|
for json_file in json_files:
|
||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
with open(json_file, encoding="utf-8") as json_f:
|
||||||
chat_data = json.load(json_f)
|
chat_data = json.load(json_f)
|
||||||
|
|
||||||
contact_name = json_file.stem
|
contact_name = json_file.stem
|
||||||
@@ -577,10 +623,10 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if count >= max_count and max_count > 0:
|
if count >= max_count and max_count > 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
from_user = message.get('fromUser', '')
|
from_user = message.get("fromUser", "")
|
||||||
content = message.get('content', '')
|
content = message.get("content", "")
|
||||||
message_text = message.get('message', '')
|
message_text = message.get("message", "")
|
||||||
create_time = message.get('createTime', 0)
|
create_time = message.get("createTime", 0)
|
||||||
|
|
||||||
# Skip non-text messages unless requested
|
# Skip non-text messages unless requested
|
||||||
if not include_non_text:
|
if not include_non_text:
|
||||||
@@ -595,8 +641,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if create_time:
|
if create_time:
|
||||||
try:
|
try:
|
||||||
timestamp = datetime.fromtimestamp(create_time)
|
timestamp = datetime.fromtimestamp(create_time)
|
||||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
except:
|
except (ValueError, OSError):
|
||||||
time_str = str(create_time)
|
time_str = str(create_time)
|
||||||
else:
|
else:
|
||||||
time_str = "Unknown"
|
time_str = "Unknown"
|
||||||
@@ -613,7 +659,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error exporting WeChat chat history: {e}")
|
print(f"Error exporting WeChat chat history: {e}")
|
||||||
|
|
||||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||||
"""
|
"""
|
||||||
Export WeChat chat history using wechat-exporter tool.
|
Export WeChat chat history using wechat-exporter tool.
|
||||||
|
|
||||||
@@ -642,16 +688,21 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||||
if requirements_file.exists():
|
if requirements_file.exists():
|
||||||
print("Installing wechat-exporter requirements...")
|
print("Installing wechat-exporter requirements...")
|
||||||
subprocess.run([
|
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||||
"uv", "pip", "install", "-r", str(requirements_file)
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
# Run the export command
|
# Run the export command
|
||||||
print("Running wechat-exporter...")
|
print("Running wechat-exporter...")
|
||||||
result = subprocess.run([
|
result = subprocess.run(
|
||||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
[
|
||||||
"export-all", str(export_path)
|
sys.executable,
|
||||||
], capture_output=True, text=True, check=True)
|
str(self.wechat_exporter_dir / "main.py"),
|
||||||
|
"export-all",
|
||||||
|
str(export_path),
|
||||||
|
],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
|
||||||
print("Export command output:")
|
print("Export command output:")
|
||||||
print(result.stdout)
|
print(result.stdout)
|
||||||
@@ -662,7 +713,9 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
# Check if export was successful
|
# Check if export was successful
|
||||||
if export_path.exists() and any(export_path.glob("*.json")):
|
if export_path.exists() and any(export_path.glob("*.json")):
|
||||||
json_files = list(export_path.glob("*.json"))
|
json_files = list(export_path.glob("*.json"))
|
||||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
print(
|
||||||
|
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||||
|
)
|
||||||
return export_path
|
return export_path
|
||||||
else:
|
else:
|
||||||
print("Export completed but no JSON files found")
|
print("Export completed but no JSON files found")
|
||||||
@@ -678,7 +731,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||||
"""
|
"""
|
||||||
Find existing WeChat exports or create new ones.
|
Find existing WeChat exports or create new ones.
|
||||||
|
|
||||||
@@ -697,7 +750,7 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
Path("./wechat_export"),
|
Path("./wechat_export"),
|
||||||
Path("./wechat_export_direct"),
|
Path("./wechat_export_direct"),
|
||||||
Path("./wechat_chat_history"),
|
Path("./wechat_chat_history"),
|
||||||
Path("./chat_export")
|
Path("./chat_export"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for export_dir_path in possible_export_dirs:
|
for export_dir_path in possible_export_dirs:
|
||||||
@@ -714,6 +767,8 @@ Message: {readable_text if readable_text else message_text}
|
|||||||
if exported_path:
|
if exported_path:
|
||||||
export_dirs = [exported_path]
|
export_dirs = [exported_path]
|
||||||
else:
|
else:
|
||||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
print(
|
||||||
|
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||||
|
)
|
||||||
|
|
||||||
return export_dirs
|
return export_dirs
|
||||||
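For orientation, here is a minimal usage sketch of the reader as refactored above. The import path is an assumption (the reader's module location is not shown in this hunk), and the keyword arguments are simply the `load_kwargs` keys that `load_data()` reads:

```python
# Hypothetical usage sketch; the module name below is assumed, not taken from the diff.
from wechat_history_reader import WeChatHistoryReader  # assumed module name

reader = WeChatHistoryReader()
docs = reader.load_data(
    wechat_export_dir="./wechat_export_direct",  # one of the directories probed above
    concatenate_messages=True,   # group adjacent messages into one document
    max_length=1000,             # soft cap on characters per concatenated group
    time_window_minutes=30,      # start a new group after a 30-minute gap
    include_non_text=False,      # skip image/voice/video/appmsg/system messages
    max_count=100,               # stop after 100 documents
)
print(f"Loaded {len(docs)} WeChat documents")
```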
apps/imessage_data/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""iMessage data processing module."""
apps/imessage_data/imessage_reader.py (new file, 342 lines)
@@ -0,0 +1,342 @@
"""
iMessage data reader.

Reads and processes iMessage conversation data from the macOS Messages database.
"""

import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any

from llama_index.core import Document
from llama_index.core.readers.base import BaseReader


class IMessageReader(BaseReader):
    """
    iMessage data reader.

    Reads iMessage conversation data from the macOS Messages database (chat.db).
    Processes conversations into structured documents with metadata.
    """

    def __init__(self, concatenate_conversations: bool = True) -> None:
        """
        Initialize.

        Args:
            concatenate_conversations: Whether to concatenate messages within conversations for better context
        """
        self.concatenate_conversations = concatenate_conversations

    def _get_default_chat_db_path(self) -> Path:
        """
        Get the default path to the iMessage chat database.

        Returns:
            Path to the chat.db file
        """
        home = Path.home()
        return home / "Library" / "Messages" / "chat.db"

    def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
        """
        Convert Cocoa timestamp to readable format.

        Args:
            cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)

        Returns:
            Formatted timestamp string
        """
        if cocoa_timestamp == 0:
            return "Unknown"

        try:
            # Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
            # Convert to seconds and add to Unix epoch
            cocoa_epoch = datetime(2001, 1, 1)
            unix_timestamp = cocoa_timestamp / 1_000_000_000  # Convert nanoseconds to seconds
            message_time = cocoa_epoch.timestamp() + unix_timestamp
            return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
        except (ValueError, OSError):
            return "Unknown"

    def _get_contact_name(self, handle_id: str) -> str:
        """
        Get a readable contact name from handle ID.

        Args:
            handle_id: The handle ID (phone number or email)

        Returns:
            Formatted contact name
        """
        if not handle_id:
            return "Unknown"

        # Clean up phone numbers and emails for display
        if "@" in handle_id:
            return handle_id  # Email address
        elif handle_id.startswith("+"):
            return handle_id  # International phone number
        else:
            # Try to format as phone number
            digits = "".join(filter(str.isdigit, handle_id))
            if len(digits) == 10:
                return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
            elif len(digits) == 11 and digits[0] == "1":
                return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
            else:
                return handle_id

    def _read_messages_from_db(self, db_path: Path) -> list[dict]:
        """
        Read messages from the iMessage database.

        Args:
            db_path: Path to the chat.db file

        Returns:
            List of message dictionaries
        """
        if not db_path.exists():
            print(f"iMessage database not found at: {db_path}")
            return []

        try:
            # Connect to the database
            conn = sqlite3.connect(str(db_path))
            cursor = conn.cursor()

            # Query to get messages with chat and handle information
            query = """
                SELECT
                    m.ROWID as message_id,
                    m.text,
                    m.date,
                    m.is_from_me,
                    m.service,
                    c.chat_identifier,
                    c.display_name as chat_display_name,
                    h.id as handle_id,
                    c.ROWID as chat_id
                FROM message m
                LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
                LEFT JOIN chat c ON cmj.chat_id = c.ROWID
                LEFT JOIN handle h ON m.handle_id = h.ROWID
                WHERE m.text IS NOT NULL AND m.text != ''
                ORDER BY c.ROWID, m.date
            """

            cursor.execute(query)
            rows = cursor.fetchall()

            messages = []
            for row in rows:
                (
                    message_id,
                    text,
                    date,
                    is_from_me,
                    service,
                    chat_identifier,
                    chat_display_name,
                    handle_id,
                    chat_id,
                ) = row

                message = {
                    "message_id": message_id,
                    "text": text,
                    "timestamp": self._convert_cocoa_timestamp(date),
                    "is_from_me": bool(is_from_me),
                    "service": service or "iMessage",
                    "chat_identifier": chat_identifier or "Unknown",
                    "chat_display_name": chat_display_name or "Unknown Chat",
                    "handle_id": handle_id or "Unknown",
                    "contact_name": self._get_contact_name(handle_id or ""),
                    "chat_id": chat_id,
                }
                messages.append(message)

            conn.close()
            print(f"Found {len(messages)} messages in database")
            return messages

        except sqlite3.Error as e:
            print(f"Error reading iMessage database: {e}")
            return []
        except Exception as e:
            print(f"Unexpected error reading iMessage database: {e}")
            return []

    def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
        """
        Group messages by chat ID.

        Args:
            messages: List of message dictionaries

        Returns:
            Dictionary mapping chat_id to list of messages
        """
        chats = {}
        for message in messages:
            chat_id = message["chat_id"]
            if chat_id not in chats:
                chats[chat_id] = []
            chats[chat_id].append(message)

        return chats

    def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
        """
        Create concatenated content from chat messages.

        Args:
            chat_id: The chat ID
            messages: List of messages in the chat

        Returns:
            Concatenated text content
        """
        if not messages:
            return ""

        # Get chat info from first message
        first_msg = messages[0]
        chat_name = first_msg["chat_display_name"]
        chat_identifier = first_msg["chat_identifier"]

        # Build message content
        message_parts = []
        for message in messages:
            timestamp = message["timestamp"]
            is_from_me = message["is_from_me"]
            text = message["text"]
            contact_name = message["contact_name"]

            if is_from_me:
                prefix = "[You]"
            else:
                prefix = f"[{contact_name}]"

            if timestamp != "Unknown":
                prefix += f" ({timestamp})"

            message_parts.append(f"{prefix}: {text}")

        concatenated_text = "\n\n".join(message_parts)

        doc_content = f"""Chat: {chat_name}
Identifier: {chat_identifier}
Messages ({len(messages)} messages):

{concatenated_text}
"""
        return doc_content

    def _create_individual_content(self, message: dict) -> str:
        """
        Create content for individual message.

        Args:
            message: Message dictionary

        Returns:
            Formatted message content
        """
        timestamp = message["timestamp"]
        is_from_me = message["is_from_me"]
        text = message["text"]
        contact_name = message["contact_name"]
        chat_name = message["chat_display_name"]

        sender = "You" if is_from_me else contact_name

        return f"""Message from {sender} in chat "{chat_name}"
Time: {timestamp}
Content: {text}
"""

    def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
        """
        Load iMessage data and return as documents.

        Args:
            input_dir: Optional path to directory containing chat.db file.
                If not provided, uses default macOS location.
            **load_kwargs: Additional arguments (unused)

        Returns:
            List of Document objects containing iMessage data
        """
        docs = []

        # Determine database path
        if input_dir:
            db_path = Path(input_dir) / "chat.db"
        else:
            db_path = self._get_default_chat_db_path()

        print(f"Reading iMessage database from: {db_path}")

        # Read messages from database
        messages = self._read_messages_from_db(db_path)
        if not messages:
            return docs

        if self.concatenate_conversations:
            # Group messages by chat and create concatenated documents
            chats = self._group_messages_by_chat(messages)

            for chat_id, chat_messages in chats.items():
                if not chat_messages:
                    continue

                content = self._create_concatenated_content(chat_id, chat_messages)

                # Create metadata
                first_msg = chat_messages[0]
                last_msg = chat_messages[-1]

                metadata = {
                    "source": "iMessage",
                    "chat_id": chat_id,
                    "chat_name": first_msg["chat_display_name"],
                    "chat_identifier": first_msg["chat_identifier"],
                    "message_count": len(chat_messages),
                    "first_message_date": first_msg["timestamp"],
                    "last_message_date": last_msg["timestamp"],
                    "participants": list(
                        {msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
                    ),
                }

                doc = Document(text=content, metadata=metadata)
                docs.append(doc)

        else:
            # Create individual documents for each message
            for message in messages:
                content = self._create_individual_content(message)

                metadata = {
                    "source": "iMessage",
                    "message_id": message["message_id"],
                    "chat_id": message["chat_id"],
                    "chat_name": message["chat_display_name"],
                    "chat_identifier": message["chat_identifier"],
                    "timestamp": message["timestamp"],
                    "is_from_me": message["is_from_me"],
                    "contact_name": message["contact_name"],
                    "service": message["service"],
                }

                doc = Document(text=content, metadata=metadata)
                docs.append(doc)

        print(f"Created {len(docs)} documents from iMessage data")
        return docs
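As a quick sanity check of the Cocoa-timestamp conversion used by `_convert_cocoa_timestamp` above (Apple stores `message.date` as nanoseconds since 2001-01-01), here is a minimal standalone sketch; the input value is made up for illustration:

```python
from datetime import datetime

# Made-up example value: 700,000,000 seconds' worth of nanoseconds after 2001-01-01.
cocoa_ns = 700_000_000 * 1_000_000_000

cocoa_epoch = datetime(2001, 1, 1)
unix_seconds = cocoa_epoch.timestamp() + cocoa_ns / 1_000_000_000
print(datetime.fromtimestamp(unix_seconds).strftime("%Y-%m-%d %H:%M:%S"))
# 700,000,000 s is roughly 22.2 years, so this prints a date in early 2023.
```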
apps/imessage_rag.py (new file, 125 lines)
@@ -0,0 +1,125 @@
"""
iMessage RAG Example.

This example demonstrates how to build a RAG system on your iMessage conversation history.
"""

import asyncio
from pathlib import Path

from leann.chunking_utils import create_text_chunks

from apps.base_rag_example import BaseRAGExample
from apps.imessage_data.imessage_reader import IMessageReader


class IMessageRAG(BaseRAGExample):
    """RAG example for iMessage conversation history."""

    def __init__(self):
        super().__init__(
            name="iMessage",
            description="RAG on your iMessage conversation history",
            default_index_name="imessage_index",
        )

    def _add_specific_arguments(self, parser):
        """Add iMessage-specific arguments."""
        imessage_group = parser.add_argument_group("iMessage Parameters")
        imessage_group.add_argument(
            "--db-path",
            type=str,
            default=None,
            help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
        )
        imessage_group.add_argument(
            "--concatenate-conversations",
            action="store_true",
            default=True,
            help="Concatenate messages within conversations for better context (default: True)",
        )
        imessage_group.add_argument(
            "--no-concatenate-conversations",
            action="store_true",
            help="Process each message individually instead of concatenating by conversation",
        )
        imessage_group.add_argument(
            "--chunk-size",
            type=int,
            default=1000,
            help="Maximum characters per text chunk (default: 1000)",
        )
        imessage_group.add_argument(
            "--chunk-overlap",
            type=int,
            default=200,
            help="Overlap between text chunks (default: 200)",
        )

    async def load_data(self, args) -> list[str]:
        """Load iMessage history and convert to text chunks."""
        print("Loading iMessage conversation history...")

        # Determine concatenation setting
        concatenate = args.concatenate_conversations and not args.no_concatenate_conversations

        # Initialize iMessage reader
        reader = IMessageReader(concatenate_conversations=concatenate)

        # Load documents
        try:
            if args.db_path:
                # Use custom database path
                db_dir = str(Path(args.db_path).parent)
                documents = reader.load_data(input_dir=db_dir)
            else:
                # Use default macOS location
                documents = reader.load_data()

        except Exception as e:
            print(f"Error loading iMessage data: {e}")
            print("\nTroubleshooting tips:")
            print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
            print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
            print("3. Try specifying a custom path with --db-path if you have a backup")
            return []

        if not documents:
            print("No iMessage conversations found!")
            return []

        print(f"Loaded {len(documents)} iMessage documents")

        # Show some statistics
        total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
        print(f"Total messages: {total_messages}")

        if concatenate:
            # Show chat statistics
            chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
            unique_chats = len(set(chat_names))
            print(f"Unique conversations: {unique_chats}")

        # Convert to text chunks
        all_texts = create_text_chunks(
            documents,
            chunk_size=args.chunk_size,
            chunk_overlap=args.chunk_overlap,
        )

        # Apply max_items limit if specified
        if args.max_items > 0:
            all_texts = all_texts[: args.max_items]
            print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")

        return all_texts


async def main():
    """Main entry point."""
    app = IMessageRAG()
    await app.run()


if __name__ == "__main__":
    asyncio.run(main())
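If you want the same pipeline without the CLI wrapper, a minimal sketch using the two building blocks imported above; the chunk sizes mirror the script's defaults and the default database location is assumed to be readable:

```python
from leann.chunking_utils import create_text_chunks

from apps.imessage_data.imessage_reader import IMessageReader

# Read conversations from the default ~/Library/Messages/chat.db location.
reader = IMessageReader(concatenate_conversations=True)
documents = reader.load_data()

# Split the per-conversation documents into overlapping text chunks for indexing.
chunks = create_text_chunks(documents, chunk_size=1000, chunk_overlap=200)
print(f"{len(documents)} conversations -> {len(chunks)} chunks")
```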
apps/multimodal/vision-based-pdf-multi-vector/README.md (new file, 113 lines)
@@ -0,0 +1,113 @@
## Vision-based PDF Multi-Vector Demos (macOS/MPS)

This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.

### What you’ll run
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.

## Prerequisites (macOS)

### 1) Homebrew poppler (for pdf2image)
```bash
brew install poppler
which pdfinfo && pdfinfo -v
```

### 2) Python environment
Use uv (recommended) or pip. Python 3.9+.

Using uv:
```bash
uv pip install \
  colpali_engine \
  pdf2image \
  pillow \
  matplotlib qwen_vl_utils \
  einops \
  seaborn
```

Notes:
- On first run, models download from Hugging Face. Login/config if needed.
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
```bash
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
```

## Run the demos

### A) Local PDF example
Converts a local PDF into page images, embeds them, builds an index, and searches.

```bash
cd apps/multimodal/vision-based-pdf-multi-vector
# If you don't have the sample PDF locally, download it (ignored by Git)
mkdir -p pdfs
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
ls pdfs/2004.12832v2.pdf
# Ensure output dir exists
mkdir -p pages
python multi-vector-leann-paper-example.py
```
Expected:
- Page images in `pages/`.
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.

To use your own PDF: edit `pdf_path` near the top of the script.

### B) Similarity map + answer demo
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.

```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```
Artifacts (when enabled):
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
- Similarity maps: `./figures/similarity_map_rank{K}.png`

Key knobs in the script (top of file):
- `QUERY`: your question
- `MODEL`: `"colqwen2"` or `"colpali"`
- `USE_HF_DATASET`: set `False` to use local pages
- `PDF`, `PAGES_DIR`: for local mode
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)

## Troubleshooting
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
- First-run model downloads can be large; ensure network access (HF mirrors if needed).

## Notes
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
- For local PDFs, page images go to `./pages/`.


### Retrieval and Visualization Example

Example settings in `multi-vector-leann-similarity-map.py`:
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
- `SIMILARITY_MAP = True` (to generate heatmaps)
- `TOPK = 1` (save the top retrieved page and its similarity map)

Run:
```bash
cd apps/multimodal/vision-based-pdf-multi-vector
python multi-vector-leann-similarity-map.py
```

Outputs (by default):
- Retrieved page: `./figures/retrieved_page_rank1.png`
- Similarity map: `./figures/similarity_map_rank1.png`

Sample visualization (example result; here the query is `QUERY = "How does Vim model performance and efficiency compared to other models?"`):
![Similarity map example]

Notes:
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
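The knob list in the README above corresponds to module-level constants at the top of `multi-vector-leann-similarity-map.py`. As a hedged sketch of what a local-pages configuration might look like — only the names listed in the README are taken from the source; every value below is illustrative, not the script's defaults:

```python
# Hypothetical settings block; values are assumptions for illustration only.
QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"
MODEL = "colqwen2"          # or "colpali"
USE_HF_DATASET = False      # use local pages instead of the HF dataset
PDF = "./pdfs/2004.12832v2.pdf"
PAGES_DIR = "./pages"
INDEX_PATH = "./indexes/colvision.leann"  # assumed path under ./indexes/
TOPK = 1
FIRST_STAGE_K = 32          # illustrative value
REBUILD_INDEX = True
SIMILARITY_MAP = True
SIM_TOKEN_IDX = -1          # auto-select the most salient token
SIM_OUTPUT = "./figures/similarity_map.png"
ANSWER = False
MAX_NEW_TOKENS = 128        # illustrative value
```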
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py (new executable file, 132 lines)
@@ -0,0 +1,132 @@
#!/usr/bin/env python3
"""Simple test script to test colqwen2 forward pass with a single image."""

import os
import sys
from pathlib import Path

# Add the current directory to path to import leann_multi_vector
sys.path.insert(0, str(Path(__file__).parent))

import torch
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
from PIL import Image

# Ensure repo paths are importable
_ensure_repo_paths_importable(__file__)

# Set environment variable
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_test_image():
    """Create a simple test image."""
    # Create a simple RGB image (800x600)
    img = Image.new("RGB", (800, 600), color="white")
    return img


def load_test_image_from_file():
    """Try to load an image from the indexes directory if available."""
    # Try to find an existing image in the indexes directory
    indexes_dir = Path(__file__).parent / "indexes"

    # Look for images in common locations
    possible_paths = [
        indexes_dir / "vidore_fastplaid" / "images",
        indexes_dir / "colvision_large.leann.images",
        indexes_dir / "colvision.leann.images",
    ]

    for img_dir in possible_paths:
        if img_dir.exists():
            # Find first image file
            for ext in [".png", ".jpg", ".jpeg"]:
                for img_file in img_dir.glob(f"*{ext}"):
                    print(f"Loading test image from: {img_file}")
                    return Image.open(img_file)

    return None


def main():
    print("=" * 60)
    print("Testing ColQwen2 Forward Pass")
    print("=" * 60)

    # Step 1: Load or create test image
    print("\n[Step 1] Loading test image...")
    test_image = load_test_image_from_file()
    if test_image is None:
        print("No existing image found, creating a simple test image...")
        test_image = create_test_image()
    else:
        print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")

    # Convert to RGB if needed
    if test_image.mode != "RGB":
        test_image = test_image.convert("RGB")
        print(f"✓ Converted to RGB: {test_image.size}")

    # Step 2: Load model
    print("\n[Step 2] Loading ColQwen2 model...")
    try:
        model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
        print(f"✓ Model loaded: {model_name}")
        print(f"✓ Device: {device_str}, dtype: {dtype}")

        # Print model info
        if hasattr(model, "device"):
            print(f"✓ Model device: {model.device}")
        if hasattr(model, "dtype"):
            print(f"✓ Model dtype: {model.dtype}")

    except Exception as e:
        print(f"✗ Error loading model: {e}")
        import traceback

        traceback.print_exc()
        return

    # Step 3: Test forward pass
    print("\n[Step 3] Running forward pass...")
    try:
        # Use the _embed_images function which handles batching and forward pass
        images = [test_image]
        print(f"Processing {len(images)} image(s)...")

        doc_vecs = _embed_images(model, processor, images)

        print("✓ Forward pass completed!")
        print(f"✓ Number of embeddings: {len(doc_vecs)}")

        if len(doc_vecs) > 0:
            emb = doc_vecs[0]
            print(f"✓ Embedding shape: {emb.shape}")
            print(f"✓ Embedding dtype: {emb.dtype}")
            print("✓ Embedding stats:")
            print(f"  - Min: {emb.min().item():.4f}")
            print(f"  - Max: {emb.max().item():.4f}")
            print(f"  - Mean: {emb.mean().item():.4f}")
            print(f"  - Std: {emb.std().item():.4f}")

            # Check for NaN or Inf
            if torch.isnan(emb).any():
                print("⚠ Warning: Embedding contains NaN values!")
            if torch.isinf(emb).any():
                print("⚠ Warning: Embedding contains Inf values!")

    except Exception as e:
        print(f"✗ Error during forward pass: {e}")
        import traceback

        traceback.print_exc()
        return

    print("\n" + "=" * 60)
    print("Test completed successfully!")
    print("=" * 60)


if __name__ == "__main__":
    main()
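The per-token embeddings this test prints are the inputs to the late-interaction (MaxSim) scoring that ColPali/ColQwen2 retrieval uses downstream. A self-contained sketch of that scoring rule, independent of the leann_multi_vector helpers and using toy tensors in place of real embeddings:

```python
import torch


def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> torch.Tensor:
    """ColBERT-style late interaction: for each query token, take its best-matching
    document token and sum those maxima. Shapes: (nq, d) and (nd, d)."""
    sim = query_vecs @ doc_vecs.T       # (nq, nd) token-to-token similarities
    return sim.max(dim=1).values.sum()  # best doc token per query token, then sum


# Toy tensors standing in for the query/page embeddings produced above.
q = torch.randn(12, 128)
d = torch.randn(700, 128)
print(maxsim_score(q, d))
```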
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png (new binary file, 166 KiB — not shown)
apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py (new file, 1452 lines — diff suppressed because it is too large)
@@ -0,0 +1,112 @@
|
|||||||
|
# pip install pdf2image
|
||||||
|
# pip install pymilvus
|
||||||
|
# pip install colpali_engine
|
||||||
|
# pip install tqdm
|
||||||
|
# pip install pillow
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Ensure local leann packages are importable before importing them
|
||||||
|
_repo_root = Path(__file__).resolve().parents[3]
|
||||||
|
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||||
|
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||||
|
if str(_leann_core_src) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_core_src))
|
||||||
|
if str(_leann_hnsw_pkg) not in sys.path:
|
||||||
|
sys.path.append(str(_leann_hnsw_pkg))
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from colpali_engine.models import ColPali
|
||||||
|
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||||
|
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||||
|
_device_str = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else (
|
||||||
|
"mps"
|
||||||
|
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
device = get_torch_device(_device_str)
|
||||||
|
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||||
|
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||||
|
model_name = "vidore/colpali-v1.2"
|
||||||
|
|
||||||
|
model = ColPali.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=_dtype,
|
||||||
|
device_map=device,
|
||||||
|
).eval()
|
||||||
|
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||||
|
|
||||||
|
queries = [
|
||||||
|
"How to end-to-end retrieval with ColBert",
|
||||||
|
"Where is ColBERT performance Table, including text representation results?",
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[str](queries),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_queries(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
qs: list[torch.Tensor] = []
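# each entry holds one query's per-token (multi-vector) embeddings, moved to CPU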
|
||||||
|
for batch_query in dataloader:
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||||
|
embeddings_query = model(**batch_query)
|
||||||
|
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||||
|
print(qs[0].shape)
|
||||||
|
# %%
|
||||||
|
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||||
|
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset=ListDataset[Image.Image](images),
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
collate_fn=lambda x: processor.process_images(x),
|
||||||
|
)
|
||||||
|
|
||||||
|
ds: list[torch.Tensor] = []
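# one multi-vector embedding per page image, kept on CPU for indexing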
|
||||||
|
for batch_doc in tqdm(dataloader):
|
||||||
|
with torch.no_grad():
|
||||||
|
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||||
|
embeddings_doc = model(**batch_doc)
|
||||||
|
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||||
|
|
||||||
|
print(ds[0].shape)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Build HNSW index via LeannRetriever primitives and run search
|
||||||
|
index_path = "./indexes/colpali.leann"
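# NOTE: LeannRetriever is not imported above; it is assumed to be provided by the local leann packages added to sys.path earlier in this script.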
|
||||||
|
retriever = LeannRetriever(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||||
|
retriever.create_collection()
|
||||||
|
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||||
|
for i in range(len(filepaths)):
|
||||||
|
data = {
|
||||||
|
"colbert_vecs": ds[i].float().numpy(),
|
||||||
|
"doc_id": i,
|
||||||
|
"filepath": filepaths[i],
|
||||||
|
}
|
||||||
|
retriever.insert(data)
|
||||||
|
retriever.create_index()
|
||||||
|
for query in qs:
|
||||||
|
query_np = query.float().numpy()
|
||||||
|
result = retriever.search(query_np, topk=1)
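# search returns a list of (score, doc_id) pairs; doc_id indexes into filepaths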
|
||||||
|
print(filepaths[result[0][1]])
|
||||||
@@ -0,0 +1,713 @@
|
|||||||
|
## Jupyter-style notebook script
|
||||||
|
# %%
|
||||||
|
# uv pip install matplotlib qwen_vl_utils
|
||||||
|
import argparse
|
||||||
|
import faulthandler
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Enable faulthandler to get stack trace on segfault
|
||||||
|
faulthandler.enable()
|
||||||
|
|
||||||
|
|
||||||
|
from leann_multi_vector import ( # utility functions/classes
|
||||||
|
_ensure_repo_paths_importable,
|
||||||
|
_load_images_from_dir,
|
||||||
|
_maybe_convert_pdf_to_images,
|
||||||
|
_load_colvision,
|
||||||
|
_embed_images,
|
||||||
|
_embed_queries,
|
||||||
|
_build_index,
|
||||||
|
_load_retriever_if_index_exists,
|
||||||
|
_generate_similarity_map,
|
||||||
|
_build_fast_plaid_index,
|
||||||
|
_load_fast_plaid_index_if_exists,
|
||||||
|
_search_fast_plaid,
|
||||||
|
_get_fast_plaid_image,
|
||||||
|
_get_fast_plaid_metadata,
|
||||||
|
QwenVL,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Config
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
||||||
|
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||||
|
|
||||||
|
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||||
|
USE_HF_DATASET: bool = True
|
||||||
|
# Single dataset name (used when DATASET_NAMES is None)
|
||||||
|
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||||
|
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||||
|
# Can be:
|
||||||
|
# - List of strings: ["dataset1", "dataset2"]
|
||||||
|
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||||
|
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||||
|
#
|
||||||
|
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||||
|
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||||
|
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||||
|
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||||
|
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||||
|
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||||
|
# - "huggingface/document-images" (if available)
|
||||||
|
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||||
|
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||||
|
DATASET_NAMES = [
|
||||||
|
"weaviate/arXiv-AI-papers-multi-vector",
|
||||||
|
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||||
|
]
|
||||||
|
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||||
|
# Set to None to try loading all available splits automatically
|
||||||
|
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||||
|
# Image field name in the dataset (auto-detect if None)
|
||||||
|
# Common names: "page_image", "image", "images", "img"
|
||||||
|
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||||
|
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||||
|
|
||||||
|
# Local pages (used when USE_HF_DATASET == False)
|
||||||
|
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||||
|
PAGES_DIR: str = "./pages"
|
||||||
|
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||||
|
# If set, images will be loaded directly from this folder
|
||||||
|
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||||
|
# Whether to recursively search subdirectories when loading from custom folder
|
||||||
|
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||||
|
|
||||||
|
# Index + retrieval settings
|
||||||
|
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||||
|
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||||
|
# Fast-Plaid index settings (alternative to LEANN index)
|
||||||
|
# These are now command-line arguments (see CLI overrides section)
|
||||||
|
TOPK: int = 3
|
||||||
|
FIRST_STAGE_K: int = 500
|
||||||
|
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||||
|
|
||||||
|
# Artifacts
|
||||||
|
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||||
|
SIMILARITY_MAP: bool = True
|
||||||
|
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||||
|
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||||
|
ANSWER: bool = True
|
||||||
|
MAX_NEW_TOKENS: int = 1024
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# CLI overrides
|
||||||
|
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||||
|
parser.add_argument(
|
||||||
|
"--search-method",
|
||||||
|
type=str,
|
||||||
|
choices=["ann", "exact", "exact-all"],
|
||||||
|
default="ann",
|
||||||
|
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default=QUERY,
|
||||||
|
help=f"Query string to search for. Default: '{QUERY}'",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fast-plaid",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast-plaid-index-path",
|
||||||
|
type=str,
|
||||||
|
default="./indexes/colvision_fastplaid",
|
||||||
|
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk",
|
||||||
|
type=int,
|
||||||
|
default=TOPK,
|
||||||
|
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--custom-folder",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--recursive",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||||
|
)
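# parse_known_args tolerates extra arguments (e.g. when this file is run as a notebook cell), instead of erroring out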
|
||||||
|
cli_args, _unknown = parser.parse_known_args()
|
||||||
|
SEARCH_METHOD: str = cli_args.search_method
|
||||||
|
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||||
|
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||||
|
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||||
|
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||||
|
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||||
|
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||||
|
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# Step 1: Check if we can skip data loading (index already exists)
|
||||||
|
retriever: Optional[Any] = None
|
||||||
|
fast_plaid_index: Optional[Any] = None
|
||||||
|
need_to_build_index = REBUILD_INDEX
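# start from the REBUILD_INDEX flag; flipped to True below when no usable index is found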
|
||||||
|
|
||||||
|
if USE_FAST_PLAID:
|
||||||
|
# Fast-Plaid index handling
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
try:
|
||||||
|
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||||
|
if fast_plaid_index is not None:
|
||||||
|
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||||
|
need_to_build_index = False
|
||||||
|
else:
|
||||||
|
print(f"Fast-Plaid index not found, will build new index")
|
||||||
|
need_to_build_index = True
|
||||||
|
except Exception as e:
|
||||||
|
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||||
|
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||||
|
print("Will rebuild the index...")
|
||||||
|
need_to_build_index = True
|
||||||
|
fast_plaid_index = None
|
||||||
|
else:
|
||||||
|
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||||
|
need_to_build_index = True
|
||||||
|
else:
|
||||||
|
# Original LEANN index handling
|
||||||
|
if not REBUILD_INDEX:
|
||||||
|
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||||
|
if retriever is not None:
|
||||||
|
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||||
|
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||||
|
need_to_build_index = False
|
||||||
|
else:
|
||||||
|
print(f"Index not found, will build new index")
|
||||||
|
need_to_build_index = True
|
||||||
|
else:
|
||||||
|
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||||
|
need_to_build_index = True
|
||||||
|
|
||||||
|
# Step 2: Load data only if we need to build the index
|
||||||
|
if need_to_build_index:
|
||||||
|
print("Loading dataset...")
|
||||||
|
# Check for custom folder path first (takes precedence)
|
||||||
|
if CUSTOM_FOLDER_PATH:
|
||||||
|
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||||
|
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||||
|
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||||
|
if CUSTOM_FOLDER_RECURSIVE:
|
||||||
|
print(" (recursive mode: searching subdirectories)")
|
||||||
|
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||||
|
print(f" Found {len(filepaths)} image files")
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||||
|
)
|
||||||
|
print(f" Successfully loaded {len(images)} images")
|
||||||
|
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||||
|
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||||
|
elif USE_HF_DATASET:
|
||||||
|
from datasets import load_dataset, concatenate_datasets, DatasetDict
|
||||||
|
|
||||||
|
# Determine which datasets to load
|
||||||
|
if DATASET_NAMES is not None:
|
||||||
|
dataset_names_to_load = DATASET_NAMES
|
||||||
|
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||||
|
else:
|
||||||
|
dataset_names_to_load = [DATASET_NAME]
|
||||||
|
print(f"Loading single dataset: {DATASET_NAME}")
|
||||||
|
|
||||||
|
# Load and combine datasets
|
||||||
|
all_datasets_to_concat = []
|
||||||
|
|
||||||
|
for dataset_entry in dataset_names_to_load:
|
||||||
|
# Handle both string and tuple formats
|
||||||
|
if isinstance(dataset_entry, tuple):
|
||||||
|
dataset_name, config_name = dataset_entry
|
||||||
|
else:
|
||||||
|
dataset_name = dataset_entry
|
||||||
|
config_name = None
|
||||||
|
|
||||||
|
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||||
|
|
||||||
|
# Load dataset to check available splits
|
||||||
|
# If config_name is provided, use it; otherwise try without config
|
||||||
|
try:
|
||||||
|
if config_name:
|
||||||
|
dataset_dict = load_dataset(dataset_name, config_name)
|
||||||
|
else:
|
||||||
|
dataset_dict = load_dataset(dataset_name)
|
||||||
|
except ValueError as e:
|
||||||
|
if "Config name is missing" in str(e):
|
||||||
|
# Try to get available configs and suggest
|
||||||
|
from datasets import get_dataset_config_names
|
||||||
|
try:
|
||||||
|
available_configs = get_dataset_config_names(dataset_name)
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset '{dataset_name}' requires a config name. "
|
||||||
|
f"Available configs: {available_configs}. "
|
||||||
|
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||||
|
) from e
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset '{dataset_name}' requires a config name. "
|
||||||
|
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Determine which splits to load
|
||||||
|
if DATASET_SPLITS is None:
|
||||||
|
# Auto-detect: try to load all available splits
|
||||||
|
available_splits = list(dataset_dict.keys())
|
||||||
|
print(f" Auto-detected splits: {available_splits}")
|
||||||
|
splits_to_load = available_splits
|
||||||
|
else:
|
||||||
|
splits_to_load = DATASET_SPLITS
|
||||||
|
|
||||||
|
# Load and concatenate multiple splits for this dataset
|
||||||
|
datasets_to_concat = []
|
||||||
|
for split in splits_to_load:
|
||||||
|
if split not in dataset_dict:
|
||||||
|
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||||
|
continue
|
||||||
|
split_dataset = dataset_dict[split]
|
||||||
|
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||||
|
datasets_to_concat.append(split_dataset)
|
||||||
|
|
||||||
|
if not datasets_to_concat:
|
||||||
|
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Concatenate splits for this dataset
|
||||||
|
if len(datasets_to_concat) > 1:
|
||||||
|
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||||
|
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||||
|
else:
|
||||||
|
combined_dataset = datasets_to_concat[0]
|
||||||
|
|
||||||
|
all_datasets_to_concat.append(combined_dataset)
|
||||||
|
|
||||||
|
if not all_datasets_to_concat:
|
||||||
|
raise RuntimeError("No valid datasets or splits found.")
|
||||||
|
|
||||||
|
# Concatenate all datasets
|
||||||
|
if len(all_datasets_to_concat) > 1:
|
||||||
|
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||||
|
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||||
|
else:
|
||||||
|
dataset = all_datasets_to_concat[0]
|
||||||
|
|
||||||
|
# Apply MAX_DOCS limit if specified
|
||||||
|
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||||
|
if N < len(dataset):
|
||||||
|
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||||
|
dataset = dataset.select(range(N))
|
||||||
|
|
||||||
|
# Auto-detect image field name if not specified
|
||||||
|
if IMAGE_FIELD_NAME is None:
|
||||||
|
# Check multiple samples to find the most common image field
|
||||||
|
# (useful when datasets are merged and may have different field names)
|
||||||
|
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||||
|
field_counts = {}
|
||||||
|
|
||||||
|
# Check first few samples to find image fields
|
||||||
|
num_samples_to_check = min(10, len(dataset))
|
||||||
|
for sample_idx in range(num_samples_to_check):
|
||||||
|
sample = dataset[sample_idx]
|
||||||
|
for field in possible_image_fields:
|
||||||
|
if field in sample and sample[field] is not None:
|
||||||
|
value = sample[field]
|
||||||
|
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||||
|
field_counts[field] = field_counts.get(field, 0) + 1
|
||||||
|
|
||||||
|
# Choose the most common field, or first found if tied
|
||||||
|
if field_counts:
|
||||||
|
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||||
|
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||||
|
else:
|
||||||
|
# Fallback: check first sample only
|
||||||
|
sample = dataset[0]
|
||||||
|
image_field = None
|
||||||
|
for field in possible_image_fields:
|
||||||
|
if field in sample:
|
||||||
|
value = sample[field]
|
||||||
|
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||||
|
image_field = field
|
||||||
|
break
|
||||||
|
if image_field is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||||
|
f"Please specify IMAGE_FIELD_NAME manually."
|
||||||
|
)
|
||||||
|
print(f"Auto-detected image field: '{image_field}'")
|
||||||
|
else:
|
||||||
|
image_field = IMAGE_FIELD_NAME
|
||||||
|
if image_field not in dataset[0]:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
filepaths: list[str] = []
|
||||||
|
images: list[Image.Image] = []
|
||||||
|
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||||
|
p = dataset[i]
|
||||||
|
# Try to compose a descriptive identifier
|
||||||
|
# Handle different dataset structures
|
||||||
|
identifier_parts = []
|
||||||
|
|
||||||
|
# Helper function to safely get field value
|
||||||
|
def safe_get(field_name, default=None):
|
||||||
|
if field_name in p and p[field_name] is not None:
|
||||||
|
return p[field_name]
|
||||||
|
return default
|
||||||
|
|
||||||
|
# Try to get various identifier fields
|
||||||
|
if safe_get("paper_arxiv_id"):
|
||||||
|
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||||
|
if safe_get("paper_title"):
|
||||||
|
identifier_parts.append(f"title:{p['paper_title']}")
|
||||||
|
if safe_get("page_number") is not None:
|
||||||
|
try:
|
||||||
|
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# If conversion fails, use the raw value or skip
|
||||||
|
if p['page_number']:
|
||||||
|
identifier_parts.append(f"page:{p['page_number']}")
|
||||||
|
if safe_get("page_id"):
|
||||||
|
identifier_parts.append(f"id:{p['page_id']}")
|
||||||
|
elif safe_get("questionId"):
|
||||||
|
identifier_parts.append(f"qid:{p['questionId']}")
|
||||||
|
elif safe_get("docId"):
|
||||||
|
identifier_parts.append(f"docId:{p['docId']}")
|
||||||
|
elif safe_get("id"):
|
||||||
|
identifier_parts.append(f"id:{p['id']}")
|
||||||
|
|
||||||
|
# If no identifier parts found, create one from index
|
||||||
|
if identifier_parts:
|
||||||
|
identifier = "|".join(identifier_parts)
|
||||||
|
else:
|
||||||
|
# Create identifier from available fields or index
|
||||||
|
fallback_parts = []
|
||||||
|
# Try common fields that might exist
|
||||||
|
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||||
|
if safe_get(field):
|
||||||
|
fallback_parts.append(f"{field}:{p[field]}")
|
||||||
|
break
|
||||||
|
if fallback_parts:
|
||||||
|
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||||
|
else:
|
||||||
|
identifier = f"doc_{i}"
|
||||||
|
|
||||||
|
filepaths.append(identifier)
|
||||||
|
|
||||||
|
# Get image - try detected field first, then fallback to other common fields
|
||||||
|
img = None
|
||||||
|
if image_field in p and p[image_field] is not None:
|
||||||
|
img = p[image_field]
|
||||||
|
else:
|
||||||
|
# Fallback: try other common image field names
|
||||||
|
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||||
|
if fallback_field in p and p[fallback_field] is not None:
|
||||||
|
img = p[fallback_field]
|
||||||
|
break
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||||
|
f"Expected field: {image_field}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure it's a PIL Image
|
||||||
|
if not isinstance(img, Image.Image):
|
||||||
|
if hasattr(img, 'convert'):
|
||||||
|
img = img.convert('RGB')
|
||||||
|
else:
|
||||||
|
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||||
|
images.append(img)
|
||||||
|
else:
|
||||||
|
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||||
|
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||||
|
if not images:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||||
|
)
|
||||||
|
print(f"Loaded {len(images)} images")
|
||||||
|
|
||||||
|
# Memory check before loading model
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
mem_info = process.memory_info()
|
||||||
|
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||||
|
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("Skipping dataset loading (using existing index)")
|
||||||
|
filepaths = [] # Not needed when using existing index
|
||||||
|
images = [] # Not needed when using existing index
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||||
|
print("Step 3: Loading model and processor...")
|
||||||
|
print(f" Model: {MODEL}")
|
||||||
|
try:
|
||||||
|
import sys
|
||||||
|
print(f" Python version: {sys.version}")
|
||||||
|
print(f" Python executable: {sys.executable}")
|
||||||
|
|
||||||
|
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||||
|
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||||
|
|
||||||
|
# Memory check after loading model
|
||||||
|
try:
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
process = psutil.Process(os.getpid())
|
||||||
|
mem_info = process.memory_info()
|
||||||
|
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||||
|
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 4: Build index if needed
|
||||||
|
if need_to_build_index:
|
||||||
|
print("Step 4: Building index...")
|
||||||
|
print(f" Number of images: {len(images)}")
|
||||||
|
print(f" Number of filepaths: {len(filepaths)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(" Embedding images...")
|
||||||
|
doc_vecs = _embed_images(model, processor, images)
|
||||||
|
print(f" Embedded {len(doc_vecs)} documents")
|
||||||
|
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
|
||||||
|
if USE_FAST_PLAID:
|
||||||
|
# Build Fast-Plaid index
|
||||||
|
print(" Building Fast-Plaid index...")
|
||||||
|
try:
|
||||||
|
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||||
|
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||||
|
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||||
|
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Clear memory
|
||||||
|
print(" Clearing memory...")
|
||||||
|
del images, filepaths, doc_vecs
|
||||||
|
else:
|
||||||
|
# Build original LEANN index
|
||||||
|
try:
|
||||||
|
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||||
|
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Clear memory
|
||||||
|
print(" Clearing memory...")
|
||||||
|
del images, filepaths, doc_vecs
|
||||||
|
|
||||||
|
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 5: Embed query and search
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||||
|
query_embed_secs = time.perf_counter() - _t0
|
||||||
|
|
||||||
|
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||||
|
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||||
|
|
||||||
|
# Run the selected search method and time it
|
||||||
|
if USE_FAST_PLAID:
|
||||||
|
# Fast-Plaid search
|
||||||
|
if fast_plaid_index is None:
|
||||||
|
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||||
|
if fast_plaid_index is None:
|
||||||
|
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||||
|
|
||||||
|
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||||
|
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||||
|
else:
|
||||||
|
# Original LEANN search
|
||||||
|
query_np = q_vec.float().numpy()
_t0 = time.perf_counter()  # reset the timer so the search timings below exclude query embedding
|
||||||
|
|
||||||
|
if SEARCH_METHOD == "ann":
|
||||||
|
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
search_secs = time.perf_counter() - _t0
|
||||||
|
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||||
|
elif SEARCH_METHOD == "exact":
|
||||||
|
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||||
|
search_secs = time.perf_counter() - _t0
|
||||||
|
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||||
|
elif SEARCH_METHOD == "exact-all":
|
||||||
|
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||||
|
search_secs = time.perf_counter() - _t0
|
||||||
|
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||||
|
else:
|
||||||
|
results = []
|
||||||
|
if not results:
|
||||||
|
print("No results found.")
|
||||||
|
else:
|
||||||
|
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||||
|
print("\n[DEBUG] Retrieval details:")
|
||||||
|
top_images: list[Image.Image] = []
|
||||||
|
image_hashes = {} # Track image hashes to detect duplicates
|
||||||
|
|
||||||
|
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||||
|
# Retrieve image and metadata based on index type
|
||||||
|
if USE_FAST_PLAID:
|
||||||
|
# Fast-Plaid: load image and get metadata
|
||||||
|
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||||
|
if image is None:
|
||||||
|
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||||
|
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||||
|
top_images.append(image)
|
||||||
|
else:
|
||||||
|
# Original LEANN: retrieve from retriever
|
||||||
|
image = retriever.get_image(doc_id)
|
||||||
|
if image is None:
|
||||||
|
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = retriever.get_metadata(doc_id)
|
||||||
|
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||||
|
top_images.append(image)
|
||||||
|
|
||||||
|
# Calculate image hash to detect duplicates
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
# Convert image to bytes for hashing
|
||||||
|
img_bytes = io.BytesIO()
|
||||||
|
image.save(img_bytes, format='PNG')
|
||||||
|
image_bytes = img_bytes.getvalue()
|
||||||
|
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||||
|
|
||||||
|
# Check if this image was already seen
|
||||||
|
duplicate_info = ""
|
||||||
|
if image_hash in image_hashes:
|
||||||
|
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||||
|
else:
|
||||||
|
image_hashes[image_hash] = rank
|
||||||
|
|
||||||
|
# Print detailed information
|
||||||
|
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||||
|
if metadata:
|
||||||
|
print(f" Metadata: {metadata}")
|
||||||
|
|
||||||
|
if SAVE_TOP_IMAGE:
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
base = _Path(SAVE_TOP_IMAGE)
|
||||||
|
base.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if base.suffix:
|
||||||
|
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||||
|
else:
|
||||||
|
base.mkdir(parents=True, exist_ok=True)
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||||
|
img.save(str(out_path))
|
||||||
|
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
||||||
|
try:
|
||||||
|
score, _doc_id = results[rank - 1]
|
||||||
|
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
||||||
|
except Exception:
|
||||||
|
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 6: Similarity maps for top-K results
|
||||||
|
if results and SIMILARITY_MAP:
|
||||||
|
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||||
|
from pathlib import Path as _Path
|
||||||
|
|
||||||
|
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||||
|
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||||
|
if output_base:
|
||||||
|
if output_base.suffix:
|
||||||
|
out_dir = output_base.parent
|
||||||
|
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||||
|
out_path = str(out_dir / out_name)
|
||||||
|
else:
|
||||||
|
out_dir = output_base
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||||
|
else:
|
||||||
|
out_path = None
|
||||||
|
chosen_idx, max_sim = _generate_similarity_map(
|
||||||
|
model=model,
|
||||||
|
processor=processor,
|
||||||
|
image=img,
|
||||||
|
query=QUERY,
|
||||||
|
token_idx=token_idx,
|
||||||
|
output_path=out_path,
|
||||||
|
)
|
||||||
|
if out_path:
|
||||||
|
print(
|
||||||
|
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# %%
|
||||||
|
# Step 7: Optional answer generation
|
||||||
|
if results and ANSWER:
|
||||||
|
qwen = QwenVL(device=device_str)
|
||||||
|
_t0 = time.perf_counter()
|
||||||
|
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||||
|
gen_secs = time.perf_counter() - _t0
|
||||||
|
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||||
|
print("\nAnswer:")
|
||||||
|
print(response)
|
||||||
@@ -0,0 +1,448 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||||
|
|
||||||
|
This script uses the interface from leann_multi_vector.py to:
|
||||||
|
1. Download ViDoRe v1 datasets
|
||||||
|
2. Build indexes (LEANN or Fast-Plaid)
|
||||||
|
3. Perform retrieval
|
||||||
|
4. Evaluate using NDCG metrics
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Evaluate all ViDoRe v1 tasks
|
||||||
|
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||||
|
|
||||||
|
# Evaluate specific task
|
||||||
|
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||||
|
|
||||||
|
# Use Fast-Plaid index
|
||||||
|
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||||
|
|
||||||
|
# Rebuild index
|
||||||
|
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann_multi_vector import (
|
||||||
|
ViDoReBenchmarkEvaluator,
|
||||||
|
_ensure_repo_paths_importable,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
# ViDoRe v1 task configurations
|
||||||
|
# Prompts match MTEB task metadata prompts
|
||||||
|
VIDORE_V1_TASKS = {
|
||||||
|
"VidoreArxivQARetrieval": {
|
||||||
|
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||||
|
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreDocVQARetrieval": {
|
||||||
|
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||||
|
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreInfoVQARetrieval": {
|
||||||
|
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||||
|
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreTabfquadRetrieval": {
|
||||||
|
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||||
|
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreTatdqaRetrieval": {
|
||||||
|
"dataset_path": "vidore/tatdqa_test_beir",
|
||||||
|
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreShiftProjectRetrieval": {
|
||||||
|
"dataset_path": "vidore/shiftproject_test_beir",
|
||||||
|
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreSyntheticDocQAAIRetrieval": {
|
||||||
|
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||||
|
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||||
|
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||||
|
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||||
|
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||||
|
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||||
|
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||||
|
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Task name aliases (short names -> full names)
|
||||||
|
TASK_ALIASES = {
|
||||||
|
"arxivqa": "VidoreArxivQARetrieval",
|
||||||
|
"docvqa": "VidoreDocVQARetrieval",
|
||||||
|
"infovqa": "VidoreInfoVQARetrieval",
|
||||||
|
"tabfquad": "VidoreTabfquadRetrieval",
|
||||||
|
"tatdqa": "VidoreTatdqaRetrieval",
|
||||||
|
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||||
|
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||||
|
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||||
|
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||||
|
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_task_name(task_name: str) -> str:
|
||||||
|
"""Normalize task name (handle aliases)."""
|
||||||
|
task_name_lower = task_name.lower()
|
||||||
|
if task_name in VIDORE_V1_TASKS:
|
||||||
|
return task_name
|
||||||
|
if task_name_lower in TASK_ALIASES:
|
||||||
|
return TASK_ALIASES[task_name_lower]
|
||||||
|
# Try partial match
|
||||||
|
for alias, full_name in TASK_ALIASES.items():
|
||||||
|
if alias in task_name_lower or task_name_lower in alias:
|
||||||
|
return full_name
|
||||||
|
return task_name
|
||||||
|
|
||||||
|
|
||||||
|
def get_safe_model_name(model_name: str) -> str:
|
||||||
|
"""Get a safe model name for use in file paths."""
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
# If it's a path, use basename or hash
|
||||||
|
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||||
|
# Use basename if it's reasonable, otherwise use hash
|
||||||
|
basename = os.path.basename(model_name.rstrip("/"))
|
||||||
|
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||||
|
return basename
|
||||||
|
# Use hash for very long or problematic paths
|
||||||
|
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||||
|
# For HuggingFace model names, replace / with _
|
||||||
|
return model_name.replace("/", "_").replace(":", "_")
|
||||||
|
|
||||||
|
|
||||||
|
def load_vidore_v1_data(
|
||||||
|
dataset_path: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
split: str = "test",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load ViDoRe v1 dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
corpus: dict mapping corpus_id to PIL Image
|
||||||
|
queries: dict mapping query_id to query text
|
||||||
|
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||||
|
"""
|
||||||
|
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||||
|
|
||||||
|
# Load queries
|
||||||
|
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||||
|
|
||||||
|
queries = {}
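# prefix ids with the split so query ids line up with the ids used for qrels below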
|
||||||
|
for row in query_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
queries[query_id] = row["query"]
|
||||||
|
|
||||||
|
# Load corpus (images)
|
||||||
|
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||||
|
|
||||||
|
corpus = {}
|
||||||
|
for row in corpus_ds:
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
# Extract image from the dataset row
|
||||||
|
if "image" in row:
|
||||||
|
corpus[corpus_id] = row["image"]
|
||||||
|
elif "page_image" in row:
|
||||||
|
corpus[corpus_id] = row["page_image"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load qrels (relevance judgments)
|
||||||
|
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||||
|
|
||||||
|
qrels = {}
|
||||||
|
for row in qrels_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
if query_id not in qrels:
|
||||||
|
qrels[query_id] = {}
|
||||||
|
qrels[query_id][corpus_id] = int(row["score"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter qrels to only include queries that exist
|
||||||
|
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||||
|
|
||||||
|
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||||
|
# This is important for correct NDCG calculation
|
||||||
|
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||||
|
queries_filtered = {
|
||||||
|
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
return corpus, queries_filtered, qrels_filtered
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_task(
|
||||||
|
task_name: str,
|
||||||
|
model_name: str,
|
||||||
|
index_path: str,
|
||||||
|
use_fast_plaid: bool = False,
|
||||||
|
fast_plaid_index_path: Optional[str] = None,
|
||||||
|
rebuild_index: bool = False,
|
||||||
|
top_k: int = 1000,
|
||||||
|
first_stage_k: int = 500,
|
||||||
|
k_values: Optional[list[int]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Evaluate a single ViDoRe v1 task.
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Evaluating task: {task_name}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Normalize task name (handle aliases)
|
||||||
|
task_name = normalize_task_name(task_name)
|
||||||
|
|
||||||
|
# Get task config
|
||||||
|
if task_name not in VIDORE_V1_TASKS:
|
||||||
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||||
|
|
||||||
|
task_config = VIDORE_V1_TASKS[task_name]
|
||||||
|
dataset_path = task_config["dataset_path"]
|
||||||
|
revision = task_config["revision"]
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
corpus, queries, qrels = load_vidore_v1_data(
|
||||||
|
dataset_path=dataset_path,
|
||||||
|
revision=revision,
|
||||||
|
split="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize k_values if not provided
|
||||||
|
if k_values is None:
|
||||||
|
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||||
|
|
||||||
|
# Check if we have any queries
|
||||||
|
if len(queries) == 0:
|
||||||
|
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||||
|
# Return zero scores
|
||||||
|
scores = {}
|
||||||
|
for k in k_values:
|
||||||
|
scores[f"ndcg_at_{k}"] = 0.0
|
||||||
|
scores[f"map_at_{k}"] = 0.0
|
||||||
|
scores[f"recall_at_{k}"] = 0.0
|
||||||
|
scores[f"precision_at_{k}"] = 0.0
|
||||||
|
scores[f"mrr_at_{k}"] = 0.0
|
||||||
|
return scores
|
||||||
|
|
||||||
|
# Initialize evaluator
|
||||||
|
evaluator = ViDoReBenchmarkEvaluator(
|
||||||
|
model_name=model_name,
|
||||||
|
use_fast_plaid=use_fast_plaid,
|
||||||
|
top_k=top_k,
|
||||||
|
first_stage_k=first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build or load index
|
||||||
|
# Use safe model name for index path (different models need different indexes)
|
||||||
|
safe_model_name = get_safe_model_name(model_name)
|
||||||
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
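# when no explicit path is given, fall back to a task- and model-specific default below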
|
||||||
|
if index_path_full is None:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||||
|
if use_fast_plaid:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||||
|
|
||||||
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
|
corpus=corpus,
|
||||||
|
index_path=index_path_full,
|
||||||
|
rebuild=rebuild_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search queries
|
||||||
|
task_prompt = task_config.get("prompt")
|
||||||
|
results = evaluator.search_queries(
|
||||||
|
queries=queries,
|
||||||
|
corpus_ids=corpus_ids_ordered,
|
||||||
|
index_or_retriever=index_or_retriever,
|
||||||
|
fast_plaid_index_path=fast_plaid_index_path,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Results for {task_name}:")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for metric, value in scores.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
print(f" {metric}: {value:.5f}")
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
if output_dir:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||||
|
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||||
|
|
||||||
|
with open(results_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
print(f"\nSaved results to: {results_file}")
|
||||||
|
|
||||||
|
with open(scores_file, "w") as f:
|
||||||
|
json.dump(scores, f, indent=2)
|
||||||
|
print(f"Saved scores to: {scores_file}")
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="colqwen2",
|
||||||
|
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tasks",
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to LEANN index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fast-plaid",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Fast-Plaid instead of LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast-plaid-index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
help="Rebuild index even if it exists",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--first-stage-k",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="First stage k for LEANN search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k-values",
|
||||||
|
type=str,
|
||||||
|
default="1,3,5,10,20,100,1000",
|
||||||
|
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="./vidore_v1_results",
|
||||||
|
help="Output directory for results",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Parse k_values
|
||||||
|
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||||
|
|
||||||
|
# Determine tasks to evaluate
|
||||||
|
if args.task and args.task.lower() != "all":
|
||||||
|
tasks_to_eval = [normalize_task_name(args.task)]
|
||||||
|
elif args.tasks.lower() == "all":
|
||||||
|
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||||
|
else:
|
||||||
|
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||||
|
|
||||||
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
# Evaluate each task
|
||||||
|
all_scores = {}
|
||||||
|
for task_name in tasks_to_eval:
|
||||||
|
try:
|
||||||
|
scores = evaluate_task(
|
||||||
|
task_name=task_name,
|
||||||
|
model_name=args.model,
|
||||||
|
index_path=args.index_path,
|
||||||
|
use_fast_plaid=args.use_fast_plaid,
|
||||||
|
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||||
|
rebuild_index=args.rebuild_index,
|
||||||
|
top_k=args.top_k,
|
||||||
|
first_stage_k=args.first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
)
|
||||||
|
all_scores[task_name] = scores
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError evaluating {task_name}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
if all_scores:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("SUMMARY")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for task_name, scores in all_scores.items():
|
||||||
|
print(f"\n{task_name}:")
|
||||||
|
# Print main metrics
|
||||||
|
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||||
|
if metric in scores:
|
||||||
|
print(f" {metric}: {scores[metric]:.5f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,439 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||||
|
|
||||||
|
This script uses the interface from leann_multi_vector.py to:
|
||||||
|
1. Download ViDoRe v2 datasets
|
||||||
|
2. Build indexes (LEANN or Fast-Plaid)
|
||||||
|
3. Perform retrieval
|
||||||
|
4. Evaluate using NDCG metrics
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Evaluate all ViDoRe v2 tasks
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||||
|
|
||||||
|
# Evaluate specific task
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||||
|
|
||||||
|
# Use Fast-Plaid index
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||||
|
|
||||||
|
# Rebuild index
|
||||||
|
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann_multi_vector import (
|
||||||
|
ViDoReBenchmarkEvaluator,
|
||||||
|
_ensure_repo_paths_importable,
|
||||||
|
)
|
||||||
|
|
||||||
|
_ensure_repo_paths_importable(__file__)
|
||||||
|
|
||||||
|
# Language name to dataset language field value mapping
|
||||||
|
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||||
|
LANGUAGE_MAPPING = {
|
||||||
|
"english": "eng-Latn",
|
||||||
|
"french": "fra-Latn",
|
||||||
|
"spanish": "spa-Latn",
|
||||||
|
"german": "deu-Latn",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ViDoRe v2 task configurations
|
||||||
|
# Prompts match MTEB task metadata prompts
|
||||||
|
VIDORE_V2_TASKS = {
|
||||||
|
"Vidore2ESGReportsRetrieval": {
|
||||||
|
"dataset_path": "vidore/esg_reports_v2",
|
||||||
|
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2EconomicsReportsRetrieval": {
|
||||||
|
"dataset_path": "vidore/economics_reports_v2",
|
||||||
|
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2BioMedicalLecturesRetrieval": {
|
||||||
|
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||||
|
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||||
|
"languages": ["french", "spanish", "english", "german"],
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
"Vidore2ESGReportsHLRetrieval": {
|
||||||
|
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||||
|
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||||
|
# Note: This dataset doesn't have language filtering - all queries are English
|
||||||
|
"languages": None, # No language filtering needed
|
||||||
|
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_vidore_v2_data(
|
||||||
|
dataset_path: str,
|
||||||
|
revision: Optional[str] = None,
|
||||||
|
split: str = "test",
|
||||||
|
language: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load ViDoRe v2 dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
corpus: dict mapping corpus_id to PIL Image
|
||||||
|
queries: dict mapping query_id to query text
|
||||||
|
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||||
|
"""
|
||||||
|
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||||
|
|
||||||
|
# Load queries
|
||||||
|
query_ds = load_dataset(dataset_path, "queries", split=split, revision=revision)
|
||||||
|
|
||||||
|
# Check if dataset has language field before filtering
|
||||||
|
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||||
|
|
||||||
|
if language and has_language_field:
|
||||||
|
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||||
|
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||||
|
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||||
|
# Check if filtering resulted in empty dataset
|
||||||
|
if len(query_ds_filtered) == 0:
|
||||||
|
print(
|
||||||
|
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||||
|
)
|
||||||
|
# Try with original language value (dataset might use simple names like 'english')
|
||||||
|
print(f"Trying with original language value '{language}'...")
|
||||||
|
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||||
|
if len(query_ds_filtered) == 0:
|
||||||
|
# Try to get a sample to see actual language values
|
||||||
|
try:
|
||||||
|
sample_ds = load_dataset(
|
||||||
|
dataset_path, "queries", split=split, revision=revision
|
||||||
|
)
|
||||||
|
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||||
|
sample_langs = set(sample_ds["language"])
|
||||||
|
print(f"Available language values in dataset: {sample_langs}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||||
|
)
|
||||||
|
query_ds = query_ds_filtered
|
||||||
|
|
||||||
|
queries = {}
|
||||||
|
for row in query_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
queries[query_id] = row["query"]
|
||||||
|
|
||||||
|
# Load corpus (images)
|
||||||
|
corpus_ds = load_dataset(dataset_path, "corpus", split=split, revision=revision)
|
||||||
|
|
||||||
|
corpus = {}
|
||||||
|
for row in corpus_ds:
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
# Extract image from the dataset row
|
||||||
|
if "image" in row:
|
||||||
|
corpus[corpus_id] = row["image"]
|
||||||
|
elif "page_image" in row:
|
||||||
|
corpus[corpus_id] = row["page_image"]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"No image field found in corpus. Available fields: {list(row.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load qrels (relevance judgments)
|
||||||
|
qrels_ds = load_dataset(dataset_path, "qrels", split=split, revision=revision)
|
||||||
|
|
||||||
|
qrels = {}
|
||||||
|
for row in qrels_ds:
|
||||||
|
query_id = f"query-{split}-{row['query-id']}"
|
||||||
|
corpus_id = f"corpus-{split}-{row['corpus-id']}"
|
||||||
|
if query_id not in qrels:
|
||||||
|
qrels[query_id] = {}
|
||||||
|
qrels[query_id][corpus_id] = int(row["score"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter qrels to only include queries that exist
|
||||||
|
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||||
|
|
||||||
|
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||||
|
# This is important for correct NDCG calculation
|
||||||
|
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||||
|
queries_filtered = {
|
||||||
|
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||||
|
)
|
||||||
|
|
||||||
|
return corpus, queries_filtered, qrels_filtered
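# Hedged illustration of the shapes returned above (the IDs and score are made up; real
# corpus values are PIL images loaded from the dataset):
def _example_qrels_shape():
    """Illustrative only: one query judged relevant to one corpus page with score 1."""
    queries = {"query-test-0": "What does the report say about emissions?"}
    qrels = {"query-test-0": {"corpus-test-12": 1}}
    return queries, qrels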
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_task(
|
||||||
|
task_name: str,
|
||||||
|
model_name: str,
|
||||||
|
index_path: str,
|
||||||
|
use_fast_plaid: bool = False,
|
||||||
|
fast_plaid_index_path: Optional[str] = None,
|
||||||
|
language: Optional[str] = None,
|
||||||
|
rebuild_index: bool = False,
|
||||||
|
top_k: int = 100,
|
||||||
|
first_stage_k: int = 500,
|
||||||
|
k_values: Optional[list[int]] = None,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Evaluate a single ViDoRe v2 task.
|
||||||
|
"""
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Evaluating task: {task_name}")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
|
||||||
|
# Get task config
|
||||||
|
if task_name not in VIDORE_V2_TASKS:
|
||||||
|
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||||
|
|
||||||
|
task_config = VIDORE_V2_TASKS[task_name]
|
||||||
|
dataset_path = task_config["dataset_path"]
|
||||||
|
revision = task_config["revision"]
|
||||||
|
|
||||||
|
# Determine language
|
||||||
|
if language is None:
|
||||||
|
# Use the task's single language when exactly one is listed; otherwise leave None (no filtering)
|
||||||
|
languages = task_config.get("languages")
|
||||||
|
if languages is None:
|
||||||
|
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||||
|
language = None
|
||||||
|
elif len(languages) == 1:
|
||||||
|
language = languages[0]
|
||||||
|
else:
|
||||||
|
language = None
|
||||||
|
|
||||||
|
# Initialize k_values if not provided
|
||||||
|
if k_values is None:
|
||||||
|
k_values = [1, 3, 5, 10, 100]
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
corpus, queries, qrels = load_vidore_v2_data(
|
||||||
|
dataset_path=dataset_path,
|
||||||
|
revision=revision,
|
||||||
|
split="test",
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if we have any queries
|
||||||
|
if len(queries) == 0:
|
||||||
|
print(
|
||||||
|
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||||
|
)
|
||||||
|
# Return zero scores
|
||||||
|
scores = {}
|
||||||
|
for k in k_values:
|
||||||
|
scores[f"ndcg_at_{k}"] = 0.0
|
||||||
|
scores[f"map_at_{k}"] = 0.0
|
||||||
|
scores[f"recall_at_{k}"] = 0.0
|
||||||
|
scores[f"precision_at_{k}"] = 0.0
|
||||||
|
scores[f"mrr_at_{k}"] = 0.0
|
||||||
|
return scores
|
||||||
|
|
||||||
|
# Initialize evaluator
|
||||||
|
evaluator = ViDoReBenchmarkEvaluator(
|
||||||
|
model_name=model_name,
|
||||||
|
use_fast_plaid=use_fast_plaid,
|
||||||
|
top_k=top_k,
|
||||||
|
first_stage_k=first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build or load index
|
||||||
|
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||||
|
if index_path_full is None:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||||
|
if use_fast_plaid:
|
||||||
|
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||||
|
|
||||||
|
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||||
|
corpus=corpus,
|
||||||
|
index_path=index_path_full,
|
||||||
|
rebuild=rebuild_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search queries
|
||||||
|
task_prompt = task_config.get("prompt")
|
||||||
|
results = evaluator.search_queries(
|
||||||
|
queries=queries,
|
||||||
|
corpus_ids=corpus_ids_ordered,
|
||||||
|
index_or_retriever=index_or_retriever,
|
||||||
|
fast_plaid_index_path=fast_plaid_index_path,
|
||||||
|
task_prompt=task_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"Results for {task_name}:")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for metric, value in scores.items():
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
print(f" {metric}: {value:.5f}")
|
||||||
|
|
||||||
|
# Save results
|
||||||
|
if output_dir:
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||||
|
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||||
|
|
||||||
|
with open(results_file, "w") as f:
|
||||||
|
json.dump(results, f, indent=2)
|
||||||
|
print(f"\nSaved results to: {results_file}")
|
||||||
|
|
||||||
|
with open(scores_file, "w") as f:
|
||||||
|
json.dump(scores, f, indent=2)
|
||||||
|
print(f"Saved scores to: {scores_file}")
|
||||||
|
|
||||||
|
return scores
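# Hedged sketch of calling evaluate_task programmatically instead of via main() below;
# the task name comes from VIDORE_V2_TASKS, but the index and output paths are placeholders.
def _example_single_task_run():
    return evaluate_task(
        task_name="Vidore2ESGReportsRetrieval",
        model_name="colqwen2",
        index_path="./indexes/Vidore2ESGReportsRetrieval_colqwen2",
        output_dir="./vidore_v2_results",
    )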
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default="colqwen2",
|
||||||
|
choices=["colqwen2", "colpali"],
|
||||||
|
help="Model to use",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--task",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tasks",
|
||||||
|
type=str,
|
||||||
|
default="all",
|
||||||
|
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to LEANN index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-fast-plaid",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Fast-Plaid instead of LEANN",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast-plaid-index-path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rebuild-index",
|
||||||
|
action="store_true",
|
||||||
|
help="Rebuild index even if it exists",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--language",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Language to evaluate (default: first available)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top-k",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Top-k results to retrieve",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--first-stage-k",
|
||||||
|
type=int,
|
||||||
|
default=500,
|
||||||
|
help="First stage k for LEANN search",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--k-values",
|
||||||
|
type=str,
|
||||||
|
default="1,3,5,10,100",
|
||||||
|
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
default="./vidore_v2_results",
|
||||||
|
help="Output directory for results",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Parse k_values
|
||||||
|
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||||
|
|
||||||
|
# Determine tasks to evaluate
|
||||||
|
if args.task:
|
||||||
|
tasks_to_eval = [args.task]
|
||||||
|
elif args.tasks.lower() == "all":
|
||||||
|
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||||
|
else:
|
||||||
|
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||||
|
|
||||||
|
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||||
|
|
||||||
|
# Evaluate each task
|
||||||
|
all_scores = {}
|
||||||
|
for task_name in tasks_to_eval:
|
||||||
|
try:
|
||||||
|
scores = evaluate_task(
|
||||||
|
task_name=task_name,
|
||||||
|
model_name=args.model,
|
||||||
|
index_path=args.index_path,
|
||||||
|
use_fast_plaid=args.use_fast_plaid,
|
||||||
|
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||||
|
language=args.language,
|
||||||
|
rebuild_index=args.rebuild_index,
|
||||||
|
top_k=args.top_k,
|
||||||
|
first_stage_k=args.first_stage_k,
|
||||||
|
k_values=k_values,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
)
|
||||||
|
all_scores[task_name] = scores
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError evaluating {task_name}: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
if all_scores:
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("SUMMARY")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
|
for task_name, scores in all_scores.items():
|
||||||
|
print(f"\n{task_name}:")
|
||||||
|
# Print main metrics
|
||||||
|
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||||
|
if metric in scores:
|
||||||
|
print(f" {metric}: {scores[metric]:.5f}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
apps/semantic_file_search/leann-plus-temporal-search.py (new file, 183 lines)
@@ -0,0 +1,183 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
|
||||||
|
|
||||||
|
class TimeParser:
|
||||||
|
def __init__(self):
|
||||||
|
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||||
|
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||||
|
|
||||||
|
# Compile for performance
|
||||||
|
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||||
|
|
||||||
|
# Stop words to remove before regex parsing
|
||||||
|
self.stop_words = {
|
||||||
|
"in",
|
||||||
|
"at",
|
||||||
|
"of",
|
||||||
|
"by",
|
||||||
|
"as",
|
||||||
|
"me",
|
||||||
|
"the",
|
||||||
|
"a",
|
||||||
|
"an",
|
||||||
|
"and",
|
||||||
|
"any",
|
||||||
|
"find",
|
||||||
|
"search",
|
||||||
|
"list",
|
||||||
|
"ago",
|
||||||
|
"back",
|
||||||
|
"past",
|
||||||
|
"earlier",
|
||||||
|
}
|
||||||
|
|
||||||
|
def clean_text(self, text):
|
||||||
|
"""Remove stop words from text"""
|
||||||
|
words = text.split()
|
||||||
|
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def parse(self, text):
|
||||||
|
"""Extract all time expressions from text"""
|
||||||
|
# Clean text first
|
||||||
|
cleaned_text = self.clean_text(text)
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for match in self.regex.finditer(cleaned_text):
|
||||||
|
fuzzy = match.group(1) # "around", "about", etc.
|
||||||
|
number = int(match.group(2))
|
||||||
|
unit = match.group(3).lower()
|
||||||
|
|
||||||
|
matches.append(
|
||||||
|
{
|
||||||
|
"full_match": match.group(0),
|
||||||
|
"fuzzy": bool(fuzzy),
|
||||||
|
"number": number,
|
||||||
|
"unit": unit,
|
||||||
|
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def calculate_range(self, number, unit, is_fuzzy):
|
||||||
|
"""Convert to actual datetime range and return ISO format strings"""
|
||||||
|
units = {
|
||||||
|
"hour": timedelta(hours=number),
|
||||||
|
"day": timedelta(days=number),
|
||||||
|
"week": timedelta(weeks=number),
|
||||||
|
"month": timedelta(days=number * 30),
|
||||||
|
"year": timedelta(days=number * 365),
|
||||||
|
}
|
||||||
|
|
||||||
|
delta = units[unit]
|
||||||
|
now = datetime.now()
|
||||||
|
target = now - delta
|
||||||
|
|
||||||
|
if is_fuzzy:
|
||||||
|
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||||
|
start = (target - buffer).isoformat()
|
||||||
|
end = (target + buffer).isoformat()
|
||||||
|
else:
|
||||||
|
start = target.isoformat()
|
||||||
|
end = now.isoformat()
|
||||||
|
|
||||||
|
return (start, end)
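# Hedged usage sketch (the query string is made up): extract time expressions and their
# ISO-format date ranges from a natural-language query.
def _example_time_parse():
    parser = TimeParser()
    # e.g. one match with fuzzy=True, number=2, unit="week" and a (start, end) ISO range
    return parser.parse("notes I wrote about 2 weeks ago")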
|
||||||
|
|
||||||
|
|
||||||
|
def search_files(query, top_k=15):
|
||||||
|
"""Search the index and return results"""
|
||||||
|
# Parse time expressions
|
||||||
|
parser = TimeParser()
|
||||||
|
time_matches = parser.parse(query)
|
||||||
|
|
||||||
|
# Remove time expressions from query for semantic search
|
||||||
|
clean_query = query
|
||||||
|
if time_matches:
|
||||||
|
for match in time_matches:
|
||||||
|
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||||
|
|
||||||
|
# Check if clean_query is less than 4 characters
|
||||||
|
if len(clean_query) < 4:
|
||||||
|
print("Error: add more input for accurate results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Single query to vector DB
|
||||||
|
searcher = LeannSearcher(INDEX_PATH)
|
||||||
|
results = searcher.search(
|
||||||
|
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter by time if time expression found
|
||||||
|
if time_matches:
|
||||||
|
time_range = time_matches[0]["range"] # Use first time expression
|
||||||
|
start_time, end_time = time_range
|
||||||
|
|
||||||
|
filtered_results = []
|
||||||
|
for result in results:
|
||||||
|
# Access metadata attribute directly (not .get())
|
||||||
|
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
# Check modification date first, fall back to creation date
|
||||||
|
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||||
|
|
||||||
|
if date_str:
|
||||||
|
# Convert strings to datetime objects for proper comparison
|
||||||
|
try:
|
||||||
|
file_date = datetime.fromisoformat(date_str)
|
||||||
|
start_dt = datetime.fromisoformat(start_time)
|
||||||
|
end_dt = datetime.fromisoformat(end_time)
|
||||||
|
|
||||||
|
# Compare dates properly
|
||||||
|
if start_dt <= file_date <= end_dt:
|
||||||
|
filtered_results.append(result)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# Handle invalid date formats
|
||||||
|
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
results = filtered_results
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(f"\nSearch results for: '{query}'")
|
||||||
|
if time_matches:
|
||||||
|
print(
|
||||||
|
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||||
|
)
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||||
|
print(f"Content: {result.text}")
|
||||||
|
|
||||||
|
# Show metadata if present
|
||||||
|
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||||
|
if metadata:
|
||||||
|
if "creation_date" in metadata:
|
||||||
|
print(f"Created: {metadata['creation_date']}")
|
||||||
|
if "modification_date" in metadata:
|
||||||
|
print(f"Modified: {metadata['modification_date']}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print('Usage: python leann-plus-temporal-search.py "<search query>" [top_k]')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
query = sys.argv[1]
|
||||||
|
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||||
|
|
||||||
|
search_files(query, top_k)
|
||||||
apps/semantic_file_search/leann_index_builder.py (new file, 82 lines)
@@ -0,0 +1,82 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
def process_json_items(json_file_path):
|
||||||
|
"""Load and process JSON file with metadata items"""
|
||||||
|
|
||||||
|
with open(json_file_path, encoding="utf-8") as f:
|
||||||
|
items = json.load(f)
|
||||||
|
|
||||||
|
# Guard against empty JSON
|
||||||
|
if not items:
|
||||||
|
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||||
|
return
|
||||||
|
|
||||||
|
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||||
|
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||||
|
|
||||||
|
total_items = len(items)
|
||||||
|
items_added = 0
|
||||||
|
print(f"Processing {total_items} items...")
|
||||||
|
|
||||||
|
for idx, item in enumerate(items):
|
||||||
|
try:
|
||||||
|
# Create embedding text sentence
|
||||||
|
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||||
|
|
||||||
|
# Prepare metadata with dates
|
||||||
|
metadata = {}
|
||||||
|
if "CreationDate" in item:
|
||||||
|
metadata["creation_date"] = item["CreationDate"]
|
||||||
|
if "ContentChangeDate" in item:
|
||||||
|
metadata["modification_date"] = item["ContentChangeDate"]
|
||||||
|
|
||||||
|
# Add to builder
|
||||||
|
builder.add_text(embedding_text, metadata=metadata)
|
||||||
|
items_added += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Show progress
|
||||||
|
progress = (idx + 1) / total_items * 100
|
||||||
|
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
print() # New line after progress
|
||||||
|
|
||||||
|
# Guard against no successfully added items
|
||||||
|
if items_added == 0:
|
||||||
|
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||||
|
print("Building index...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
builder.build_index(INDEX_PATH)
|
||||||
|
print(f"✓ Index saved to {INDEX_PATH}")
|
||||||
|
except ValueError as e:
|
||||||
|
if "No chunks added" in str(e):
|
||||||
|
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||||
|
else:
|
||||||
|
raise
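# Hedged example of one input item this builder expects; the keys mirror those written by
# spotlight_index_dump.py in this app, while the values are made up.
def _example_item() -> dict:
    return {
        "Name": "budget.xlsx",
        "Path": "/Users/alice/Documents/budget.xlsx",
        "Size": 24576,
        "ContentType": "org.openxmlformats.spreadsheetml.sheet",
        "Kind": "Microsoft Excel Workbook",
        "CreationDate": "2024-01-15T09:30:00",
        "ContentChangeDate": "2024-03-02T17:45:00",
    }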
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) != 2:
|
||||||
|
print("Usage: python build_index.py <json_file>")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
json_file = sys.argv[1]
|
||||||
|
if not Path(json_file).exists():
|
||||||
|
print(f"Error: File {json_file} not found")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
process_json_items(json_file)
|
||||||
apps/semantic_file_search/spotlight_index_dump.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spotlight Metadata Dumper for Vector DB
|
||||||
|
Extracts only essential metadata for semantic search embeddings
|
||||||
|
Output is optimized for vector database storage with minimal fields
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Check platform before importing macOS-specific modules
|
||||||
|
if sys.platform != "darwin":
|
||||||
|
print("This script requires macOS (uses Spotlight)")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||||
|
|
||||||
|
# EDIT THIS LIST: Add or remove folders to search
|
||||||
|
# Can be either:
|
||||||
|
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||||
|
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||||
|
SEARCH_FOLDERS = [
|
||||||
|
"Desktop",
|
||||||
|
"Downloads",
|
||||||
|
"Documents",
|
||||||
|
"Music",
|
||||||
|
"Pictures",
|
||||||
|
"Movies",
|
||||||
|
# "Library", # Uncomment to include
|
||||||
|
# "/Applications", # Absolute path example
|
||||||
|
# "Code/Projects", # Subfolder example
|
||||||
|
# Add any other folders here
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_serializable(obj):
|
||||||
|
"""Convert NS objects to Python serializable types"""
|
||||||
|
if obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Handle NSDate
|
||||||
|
if hasattr(obj, "timeIntervalSince1970"):
|
||||||
|
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||||
|
|
||||||
|
# Handle NSArray
|
||||||
|
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||||
|
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||||
|
|
||||||
|
# Convert to string
|
||||||
|
try:
|
||||||
|
return str(obj)
|
||||||
|
except Exception:
|
||||||
|
return repr(obj)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||||
|
"""
|
||||||
|
Dump Spotlight data using public.item predicate
|
||||||
|
"""
|
||||||
|
# Build full paths from SEARCH_FOLDERS
|
||||||
|
import os
|
||||||
|
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
search_paths = []
|
||||||
|
|
||||||
|
print("Search locations:")
|
||||||
|
for folder in SEARCH_FOLDERS:
|
||||||
|
# Check if it's an absolute path or relative
|
||||||
|
if folder.startswith("/"):
|
||||||
|
full_path = folder
|
||||||
|
else:
|
||||||
|
full_path = os.path.join(home_dir, folder)
|
||||||
|
|
||||||
|
if os.path.exists(full_path):
|
||||||
|
search_paths.append(full_path)
|
||||||
|
print(f" ✓ {full_path}")
|
||||||
|
else:
|
||||||
|
print(f" ✗ {full_path} (not found)")
|
||||||
|
|
||||||
|
if not search_paths:
|
||||||
|
print("No valid search paths found!")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||||
|
|
||||||
|
# Create query with public.item predicate
|
||||||
|
query = NSMetadataQuery.alloc().init()
|
||||||
|
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||||
|
query.setPredicate_(predicate)
|
||||||
|
|
||||||
|
# Set search scopes to our specific folders
|
||||||
|
query.setSearchScopes_(search_paths)
|
||||||
|
|
||||||
|
print("Starting query...")
|
||||||
|
query.startQuery()
|
||||||
|
|
||||||
|
# Wait for gathering to complete
|
||||||
|
run_loop = NSRunLoop.currentRunLoop()
|
||||||
|
print("Gathering results...")
|
||||||
|
|
||||||
|
# Let it gather for a few seconds
|
||||||
|
for i in range(50): # 5 seconds max
|
||||||
|
run_loop.runMode_beforeDate_(
|
||||||
|
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||||
|
)
|
||||||
|
# Check gathering status periodically
|
||||||
|
if i % 10 == 0:
|
||||||
|
current_count = query.resultCount()
|
||||||
|
if current_count > 0:
|
||||||
|
print(f" Found {current_count} items so far...")
|
||||||
|
|
||||||
|
# Continue while still gathering (up to 2 more seconds)
|
||||||
|
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||||
|
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||||
|
run_loop.runMode_beforeDate_(
|
||||||
|
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||||
|
)
|
||||||
|
|
||||||
|
query.stopQuery()
|
||||||
|
|
||||||
|
total_results = query.resultCount()
|
||||||
|
print(f"Found {total_results} total items")
|
||||||
|
|
||||||
|
if total_results == 0:
|
||||||
|
print("No results found")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Process items
|
||||||
|
items_to_process = min(total_results, max_items)
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# ONLY relevant attributes for vector embeddings
|
||||||
|
# These provide essential context for semantic search without bloat
|
||||||
|
attributes = [
|
||||||
|
"kMDItemPath", # Full path for file retrieval
|
||||||
|
"kMDItemFSName", # Filename for display & embedding
|
||||||
|
"kMDItemFSSize", # Size for filtering/ranking
|
||||||
|
"kMDItemContentType", # File type for categorization
|
||||||
|
"kMDItemKind", # Human-readable type for embedding
|
||||||
|
"kMDItemFSCreationDate", # Temporal context
|
||||||
|
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Processing {items_to_process} items...")
|
||||||
|
|
||||||
|
for i in range(items_to_process):
|
||||||
|
try:
|
||||||
|
item = query.resultAtIndex_(i)
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Extract ONLY the relevant attributes
|
||||||
|
for attr in attributes:
|
||||||
|
try:
|
||||||
|
value = item.valueForAttribute_(attr)
|
||||||
|
if value is not None:
|
||||||
|
# Keep the attribute name clean (strip the kMDItem and FS prefixes for cleaner JSON keys)
|
||||||
|
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||||
|
metadata[clean_key] = convert_to_serializable(value)
|
||||||
|
except (AttributeError, ValueError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only add if we have at least a path
|
||||||
|
if metadata.get("Path"):
|
||||||
|
results.append(metadata)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing item {i}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save to JSON
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||||
|
|
||||||
|
# Show summary
|
||||||
|
print("\nSample items:")
|
||||||
|
import os
|
||||||
|
|
||||||
|
home_dir = os.path.expanduser("~")
|
||||||
|
|
||||||
|
for i, item in enumerate(results[:3]):
|
||||||
|
print(f"\n[Item {i + 1}]")
|
||||||
|
print(f" Path: {item.get('Path', 'N/A')}")
|
||||||
|
print(f" Name: {item.get('Name', 'N/A')}")
|
||||||
|
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||||
|
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||||
|
|
||||||
|
# Handle size properly
|
||||||
|
size = item.get("Size")
|
||||||
|
if size:
|
||||||
|
try:
|
||||||
|
size_int = int(size)
|
||||||
|
if size_int > 1024 * 1024:
|
||||||
|
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||||
|
elif size_int > 1024:
|
||||||
|
print(f" Size: {size_int / 1024:.2f} KB")
|
||||||
|
else:
|
||||||
|
print(f" Size: {size_int} bytes")
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
print(f" Size: {size}")
|
||||||
|
|
||||||
|
# Show dates
|
||||||
|
if "CreationDate" in item:
|
||||||
|
print(f" Created: {item['CreationDate']}")
|
||||||
|
if "ContentChangeDate" in item:
|
||||||
|
print(f" Modified: {item['ContentChangeDate']}")
|
||||||
|
|
||||||
|
# Count by type
|
||||||
|
type_counts = {}
|
||||||
|
for item in results:
|
||||||
|
content_type = item.get("ContentType", "unknown")
|
||||||
|
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||||
|
|
||||||
|
print(f"\nTotal items saved: {len(results)}")
|
||||||
|
|
||||||
|
if type_counts:
|
||||||
|
print("\nTop content types:")
|
||||||
|
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||||
|
print(f" {ct}: {count} items")
|
||||||
|
|
||||||
|
# Count by folder
|
||||||
|
folder_counts = {}
|
||||||
|
for item in results:
|
||||||
|
path = item.get("Path", "")
|
||||||
|
for folder in SEARCH_FOLDERS:
|
||||||
|
# Build the full folder path
|
||||||
|
if folder.startswith("/"):
|
||||||
|
folder_path = folder
|
||||||
|
else:
|
||||||
|
folder_path = os.path.join(home_dir, folder)
|
||||||
|
|
||||||
|
if path.startswith(folder_path):
|
||||||
|
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
if folder_counts:
|
||||||
|
print("\nItems by location:")
|
||||||
|
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||||
|
print(f" {folder}: {count} items")
|
||||||
|
|
||||||
|
return results
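# Hedged pipeline note: the JSON written here is the input for leann_index_builder.py in
# this app, so a typical run (file names are the scripts' defaults) looks like:
#   python spotlight_index_dump.py 500 spotlight_dump.json
#   python leann_index_builder.py spotlight_dump.json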
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse arguments
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
try:
|
||||||
|
max_items = int(sys.argv[1])
|
||||||
|
except ValueError:
|
||||||
|
print("Usage: python spot.py [number_of_items]")
|
||||||
|
print("Default: 10 items")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
max_items = 10
|
||||||
|
|
||||||
|
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||||
|
|
||||||
|
# Run dump
|
||||||
|
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
apps/slack_data/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
|||||||
|
# Slack MCP data integration for LEANN
|
||||||
apps/slack_data/slack_mcp_reader.py (new file, 511 lines)
@@ -0,0 +1,511 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Slack MCP Reader for LEANN
|
||||||
|
|
||||||
|
This module provides functionality to connect to Slack MCP servers and fetch message data
|
||||||
|
for indexing in LEANN. It supports various Slack MCP server implementations and provides
|
||||||
|
flexible message processing options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SlackMCPReader:
|
||||||
|
"""
|
||||||
|
Reader for Slack data via MCP (Model Context Protocol) servers.
|
||||||
|
|
||||||
|
This class connects to Slack MCP servers to fetch message data and convert it
|
||||||
|
into a format suitable for LEANN indexing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_server_command: str,
|
||||||
|
workspace_name: Optional[str] = None,
|
||||||
|
concatenate_conversations: bool = True,
|
||||||
|
max_messages_per_conversation: int = 100,
|
||||||
|
max_retries: int = 5,
|
||||||
|
retry_delay: float = 2.0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the Slack MCP Reader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
|
||||||
|
workspace_name: Optional workspace name to filter messages
|
||||||
|
concatenate_conversations: Whether to group messages by channel/thread
|
||||||
|
max_messages_per_conversation: Maximum messages to include per conversation
|
||||||
|
max_retries: Maximum number of retries for failed operations
|
||||||
|
retry_delay: Initial delay between retries in seconds
|
||||||
|
"""
|
||||||
|
self.mcp_server_command = mcp_server_command
|
||||||
|
self.workspace_name = workspace_name
|
||||||
|
self.concatenate_conversations = concatenate_conversations
|
||||||
|
self.max_messages_per_conversation = max_messages_per_conversation
|
||||||
|
self.max_retries = max_retries
|
||||||
|
self.retry_delay = retry_delay
|
||||||
|
self.mcp_process = None
|
||||||
|
|
||||||
|
async def start_mcp_server(self):
|
||||||
|
"""Start the MCP server process."""
|
||||||
|
try:
|
||||||
|
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||||
|
*self.mcp_server_command.split(),
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start MCP server: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def stop_mcp_server(self):
|
||||||
|
"""Stop the MCP server process."""
|
||||||
|
if self.mcp_process:
|
||||||
|
self.mcp_process.terminate()
|
||||||
|
await self.mcp_process.wait()
|
||||||
|
logger.info("Stopped MCP server")
|
||||||
|
|
||||||
|
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Send a request to the MCP server and get response."""
|
||||||
|
if not self.mcp_process:
|
||||||
|
raise RuntimeError("MCP server not started")
|
||||||
|
|
||||||
|
request_json = json.dumps(request) + "\n"
|
||||||
|
self.mcp_process.stdin.write(request_json.encode())
|
||||||
|
await self.mcp_process.stdin.drain()
|
||||||
|
|
||||||
|
response_line = await self.mcp_process.stdout.readline()
|
||||||
|
if not response_line:
|
||||||
|
raise RuntimeError("No response from MCP server")
|
||||||
|
|
||||||
|
return json.loads(response_line.decode().strip())
|
||||||
|
|
||||||
|
async def initialize_mcp_connection(self):
|
||||||
|
"""Initialize the MCP connection."""
|
||||||
|
init_request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 1,
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self.send_mcp_request(init_request)
|
||||||
|
if "error" in response:
|
||||||
|
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||||
|
|
||||||
|
logger.info("MCP connection initialized successfully")
|
||||||
|
|
||||||
|
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||||
|
"""List available tools from the MCP server."""
|
||||||
|
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||||
|
|
||||||
|
response = await self.send_mcp_request(list_request)
|
||||||
|
if "error" in response:
|
||||||
|
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||||
|
|
||||||
|
return response.get("result", {}).get("tools", [])
|
||||||
|
|
||||||
|
def _is_cache_sync_error(self, error: dict) -> bool:
|
||||||
|
"""Check if the error is related to users cache not being ready."""
|
||||||
|
if isinstance(error, dict):
|
||||||
|
message = error.get("message", "").lower()
|
||||||
|
return (
|
||||||
|
"users cache is not ready" in message or "sync process is still running" in message
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _retry_with_backoff(self, func, *args, **kwargs):
|
||||||
|
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries + 1):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Check if this is a cache sync error
|
||||||
|
error_dict = {}
|
||||||
|
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
||||||
|
error_dict = e.args[0]
|
||||||
|
elif "Failed to fetch messages" in str(e):
|
||||||
|
# Try to extract error from the exception message
|
||||||
|
import re
|
||||||
|
|
||||||
|
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
error_dict = ast.literal_eval(match.group(1))
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Try alternative format
|
||||||
|
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
error_dict = ast.literal_eval(match.group(1))
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._is_cache_sync_error(error_dict):
|
||||||
|
if attempt < self.max_retries:
|
||||||
|
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||||
|
logger.info(
|
||||||
|
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Not a cache sync error, don't retry
|
||||||
|
break
|
||||||
|
|
||||||
|
# If we get here, all retries failed or it's not a retryable error
|
||||||
|
raise last_exception
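# Worked example of the backoff above, assuming the defaults max_retries=5 and
# retry_delay=2.0: up to six attempts in total, waiting 2s, 4s, 8s, 16s and 32s between
# them when the failure is a users-cache sync error; other errors are re-raised right away.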
|
||||||
|
|
||||||
|
async def fetch_slack_messages(
|
||||||
|
self, channel: Optional[str] = None, limit: int = 100
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: Optional channel name to filter messages
|
||||||
|
limit: Maximum number of messages to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of message dictionaries
|
||||||
|
"""
|
||||||
|
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
||||||
|
|
||||||
|
async def _fetch_slack_messages_impl(
|
||||||
|
self, channel: Optional[str] = None, limit: int = 100
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Internal implementation of fetch_slack_messages without retry logic.
|
||||||
|
"""
|
||||||
|
# This is a generic implementation - specific MCP servers may have different tool names
|
||||||
|
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
||||||
|
|
||||||
|
tools = await self.list_available_tools()
|
||||||
|
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
||||||
|
|
||||||
|
|
||||||
|
# Look for a tool that can fetch messages - prioritize conversations_history
|
||||||
|
message_tool = None
|
||||||
|
|
||||||
|
# First, try to find conversations_history specifically
|
||||||
|
for tool in tools:
|
||||||
|
tool_name = tool.get("name", "").lower()
|
||||||
|
if "conversations_history" in tool_name:
|
||||||
|
message_tool = tool
|
||||||
|
logger.info(f"Found conversations_history tool: {tool}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# If not found, look for other message-fetching tools
|
||||||
|
if not message_tool:
|
||||||
|
for tool in tools:
|
||||||
|
tool_name = tool.get("name", "").lower()
|
||||||
|
if any(
|
||||||
|
keyword in tool_name
|
||||||
|
for keyword in ["conversations_search", "message", "history"]
|
||||||
|
):
|
||||||
|
message_tool = tool
|
||||||
|
break
|
||||||
|
|
||||||
|
if not message_tool:
|
||||||
|
raise RuntimeError("No message fetching tool found in MCP server")
|
||||||
|
|
||||||
|
# Prepare tool call parameters
|
||||||
|
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
||||||
|
if channel:
|
||||||
|
# For conversations_history, use channel_id parameter
|
||||||
|
if message_tool["name"] == "conversations_history":
|
||||||
|
tool_params["channel_id"] = channel
|
||||||
|
else:
|
||||||
|
# Try common parameter names for channel specification
|
||||||
|
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||||
|
tool_params[param_name] = channel
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"Tool parameters: {tool_params}")
|
||||||
|
|
||||||
|
fetch_request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 3,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {"name": message_tool["name"], "arguments": tool_params},
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self.send_mcp_request(fetch_request)
|
||||||
|
if "error" in response:
|
||||||
|
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
|
||||||
|
|
||||||
|
# Extract messages from response - format may vary by MCP server
|
||||||
|
result = response.get("result", {})
|
||||||
|
if "content" in result and isinstance(result["content"], list):
|
||||||
|
# Some MCP servers return content as a list
|
||||||
|
content = result["content"][0] if result["content"] else {}
|
||||||
|
if "text" in content:
|
||||||
|
try:
|
||||||
|
messages = json.loads(content["text"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||||
|
messages = self._parse_csv_messages(content["text"], channel)
|
||||||
|
else:
|
||||||
|
messages = result["content"]
|
||||||
|
else:
|
||||||
|
# Direct message format
|
||||||
|
messages = result.get("messages", [result])
|
||||||
|
|
||||||
|
return messages if isinstance(messages, list) else [messages]
|
||||||
|
|
||||||
|
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
||||||
|
"""Parse CSV format messages from Slack MCP server."""
|
||||||
|
import csv
|
||||||
|
import io
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
try:
|
||||||
|
# Split by lines and process each line as a CSV row
|
||||||
|
lines = csv_text.strip().split("\n")
|
||||||
|
if not lines:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# Skip header line if it exists
|
||||||
|
start_idx = 0
|
||||||
|
if lines[0].startswith("MsgID,UserID,UserName"):
|
||||||
|
start_idx = 1
|
||||||
|
|
||||||
|
for line in lines[start_idx:]:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse CSV line
|
||||||
|
reader = csv.reader(io.StringIO(line))
|
||||||
|
try:
|
||||||
|
row = next(reader)
|
||||||
|
if len(row) >= 7: # Ensure we have enough columns
|
||||||
|
message = {
|
||||||
|
"ts": row[0],
|
||||||
|
"user": row[1],
|
||||||
|
"username": row[2],
|
||||||
|
"real_name": row[3],
|
||||||
|
"channel": row[4],
|
||||||
|
"thread_ts": row[5],
|
||||||
|
"text": row[6],
|
||||||
|
"time": row[7] if len(row) > 7 else "",
|
||||||
|
"reactions": row[8] if len(row) > 8 else "",
|
||||||
|
"cursor": row[9] if len(row) > 9 else "",
|
||||||
|
}
|
||||||
|
messages.append(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse CSV messages: {e}")
|
||||||
|
# Fallback: treat entire text as one message
|
||||||
|
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
||||||
|
|
||||||
|
return messages
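# Hedged example of one CSV row this parser expects (values invented); judging from the
# header check and the column indices above, the layout is roughly:
#   MsgID, UserID, UserName, RealName, Channel, ThreadTs, Text, Time, Reactions, Cursor
#   1716400000.000100,U012345,alice,Alice Doe,C0GN5BX0F,,Standup moved to 10am,2024-05-22 10:00,,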
|
||||||
|
|
||||||
|
def _format_message(self, message: dict[str, Any]) -> str:
|
||||||
|
"""Format a single message for indexing."""
|
||||||
|
text = message.get("text", "")
|
||||||
|
user = message.get("user", message.get("username", "Unknown"))
|
||||||
|
channel = message.get("channel", message.get("channel_name", "Unknown"))
|
||||||
|
timestamp = message.get("ts", message.get("timestamp", ""))
|
||||||
|
|
||||||
|
# Format timestamp if available
|
||||||
|
formatted_time = ""
|
||||||
|
if timestamp:
|
||||||
|
try:
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
if isinstance(timestamp, str) and "." in timestamp:
|
||||||
|
dt = datetime.datetime.fromtimestamp(float(timestamp))
|
||||||
|
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
elif isinstance(timestamp, (int, float)):
|
||||||
|
dt = datetime.datetime.fromtimestamp(timestamp)
|
||||||
|
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
else:
|
||||||
|
formatted_time = str(timestamp)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
formatted_time = str(timestamp)
|
||||||
|
|
||||||
|
# Build formatted message
|
||||||
|
parts = []
|
||||||
|
if channel:
|
||||||
|
parts.append(f"Channel: #{channel}")
|
||||||
|
if user:
|
||||||
|
parts.append(f"User: {user}")
|
||||||
|
if formatted_time:
|
||||||
|
parts.append(f"Time: {formatted_time}")
|
||||||
|
if text:
|
||||||
|
parts.append(f"Message: {text}")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
|
||||||
|
"""Create concatenated content from multiple messages in a channel."""
|
||||||
|
if not messages:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Sort messages by timestamp if available
|
||||||
|
try:
|
||||||
|
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass # Keep original order if timestamps aren't numeric
|
||||||
|
|
||||||
|
# Limit messages per conversation
|
||||||
|
if len(messages) > self.max_messages_per_conversation:
|
||||||
|
messages = messages[-self.max_messages_per_conversation :]
|
||||||
|
|
||||||
|
# Create header
|
||||||
|
content_parts = [
|
||||||
|
f"Slack Channel: #{channel}",
|
||||||
|
f"Message Count: {len(messages)}",
|
||||||
|
f"Workspace: {self.workspace_name or 'Unknown'}",
|
||||||
|
"=" * 50,
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add messages
|
||||||
|
for message in messages:
|
||||||
|
formatted_msg = self._format_message(message)
|
||||||
|
if formatted_msg.strip():
|
||||||
|
content_parts.append(formatted_msg)
|
||||||
|
content_parts.append("-" * 30)
|
||||||
|
content_parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(content_parts)
|
||||||
|
|
||||||
|
async def get_all_channels(self) -> list[str]:
|
||||||
|
"""Get list of all available channels."""
|
||||||
|
try:
|
||||||
|
channels_list_request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": 4,
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {"name": "channels_list", "arguments": {}},
|
||||||
|
}
|
||||||
|
channels_response = await self.send_mcp_request(channels_list_request)
|
||||||
|
if "result" in channels_response:
|
||||||
|
result = channels_response["result"]
|
||||||
|
if "content" in result and isinstance(result["content"], list):
|
||||||
|
content = result["content"][0] if result["content"] else {}
|
||||||
|
if "text" in content:
|
||||||
|
# Parse the channels from the response
|
||||||
|
channels = []
|
||||||
|
lines = content["text"].split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if line.strip() and ("#" in line or "C" in line[:10]):
|
||||||
|
# Extract channel ID or name
|
||||||
|
parts = line.split()
|
||||||
|
for part in parts:
|
||||||
|
if part.startswith("C") and len(part) > 5:
|
||||||
|
channels.append(part)
|
||||||
|
elif part.startswith("#"):
|
||||||
|
channels.append(part[1:]) # Remove #
|
||||||
|
logger.info(f"Found {len(channels)} channels: {channels}")
|
||||||
|
return channels
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to get channels list: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
||||||
|
"""
|
||||||
|
Read Slack data and return formatted text chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of formatted text chunks ready for LEANN indexing
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self.start_mcp_server()
|
||||||
|
await self.initialize_mcp_connection()
|
||||||
|
|
||||||
|
all_texts = []
|
||||||
|
|
||||||
|
if channels:
|
||||||
|
# Fetch specific channels
|
||||||
|
for channel in channels:
|
||||||
|
try:
|
||||||
|
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||||
|
if messages:
|
||||||
|
if self.concatenate_conversations:
|
||||||
|
text_content = self._create_concatenated_content(messages, channel)
|
||||||
|
if text_content.strip():
|
||||||
|
all_texts.append(text_content)
|
||||||
|
else:
|
||||||
|
# Process individual messages
|
||||||
|
for message in messages:
|
||||||
|
formatted_msg = self._format_message(message)
|
||||||
|
if formatted_msg.strip():
|
||||||
|
all_texts.append(formatted_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Fetch from all available channels
|
||||||
|
logger.info("Fetching from all available channels...")
|
||||||
|
all_channels = await self.get_all_channels()
|
||||||
|
|
||||||
|
if not all_channels:
|
||||||
|
# Fallback to common channel names if we can't get the list
|
||||||
|
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
||||||
|
logger.info(f"Using fallback channels: {all_channels}")
|
||||||
|
|
||||||
|
for channel in all_channels:
|
||||||
|
try:
|
||||||
|
logger.info(f"Searching channel: {channel}")
|
||||||
|
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||||
|
if messages:
|
||||||
|
if self.concatenate_conversations:
|
||||||
|
text_content = self._create_concatenated_content(messages, channel)
|
||||||
|
if text_content.strip():
|
||||||
|
all_texts.append(text_content)
|
||||||
|
else:
|
||||||
|
# Process individual messages
|
||||||
|
for message in messages:
|
||||||
|
formatted_msg = self._format_message(message)
|
||||||
|
if formatted_msg.strip():
|
||||||
|
all_texts.append(formatted_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return all_texts
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.stop_mcp_server()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Async context manager entry."""
|
||||||
|
await self.start_mcp_server()
|
||||||
|
await self.initialize_mcp_connection()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""Async context manager exit."""
|
||||||
|
await self.stop_mcp_server()
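# Hedged usage sketch; the server command, workspace and channel names are placeholders,
# not values this repository ships with.
async def _example_slack_fetch():
    reader = SlackMCPReader("npx slack-mcp-server", workspace_name="acme")
    # read_slack_data manages the MCP server lifecycle itself (start, initialize, stop).
    return await reader.read_slack_data(channels=["general"])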
|
||||||
apps/slack_rag.py (new file, 227 lines)
@@ -0,0 +1,227 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Slack RAG Application with MCP Support
|
||||||
|
|
||||||
|
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
|
||||||
|
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from apps.base_rag_example import BaseRAGExample
|
||||||
|
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
||||||
|
|
||||||
|
|
||||||
|
class SlackMCPRAG(BaseRAGExample):
|
||||||
|
"""
|
||||||
|
RAG application for Slack messages via MCP servers.
|
||||||
|
|
||||||
|
This class provides a complete RAG pipeline for Slack data, including
|
||||||
|
MCP server connection, data fetching, indexing, and interactive chat.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
name="Slack MCP RAG",
|
||||||
|
description="RAG application for Slack messages via MCP servers",
|
||||||
|
default_index_name="slack_messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||||
|
"""Add Slack MCP-specific arguments."""
|
||||||
|
parser.add_argument(
|
||||||
|
"--mcp-server",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--workspace-name",
|
||||||
|
type=str,
|
||||||
|
help="Slack workspace name for better organization and filtering",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--channels",
|
||||||
|
nargs="+",
|
||||||
|
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--concatenate-conversations",
|
||||||
|
action="store_true",
|
||||||
|
default=True,
|
||||||
|
help="Group messages by channel/thread for better context (default: True)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-concatenate-conversations",
|
||||||
|
action="store_true",
|
||||||
|
help="Process individual messages instead of grouping by channel",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-messages-per-channel",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Maximum number of messages to include per channel (default: 100)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-connection",
|
||||||
|
action="store_true",
|
||||||
|
help="Test MCP server connection and list available tools without indexing",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-retries",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Maximum number of retries for failed operations (default: 5)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--retry-delay",
|
||||||
|
type=float,
|
||||||
|
default=2.0,
|
||||||
|
help="Initial delay between retries in seconds (default: 2.0)",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_mcp_connection(self, args) -> bool:
|
||||||
|
"""Test the MCP server connection and display available tools."""
|
||||||
|
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
reader = SlackMCPReader(
|
||||||
|
mcp_server_command=args.mcp_server,
|
||||||
|
workspace_name=args.workspace_name,
|
||||||
|
concatenate_conversations=not args.no_concatenate_conversations,
|
||||||
|
max_messages_per_conversation=args.max_messages_per_channel,
|
||||||
|
max_retries=args.max_retries,
|
||||||
|
retry_delay=args.retry_delay,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with reader:
|
||||||
|
tools = await reader.list_available_tools()
|
||||||
|
|
||||||
|
print("Successfully connected to MCP server!")
|
||||||
|
print(f"Available tools ({len(tools)}):")
|
||||||
|
|
||||||
|
for i, tool in enumerate(tools, 1):
|
||||||
|
name = tool.get("name", "Unknown")
|
||||||
|
description = tool.get("description", "No description available")
|
||||||
|
print(f"\n{i}. {name}")
|
||||||
|
print(
|
||||||
|
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show input schema if available
|
||||||
|
schema = tool.get("inputSchema", {})
|
||||||
|
if schema.get("properties"):
|
||||||
|
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||||
|
print(
|
||||||
|
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to connect to MCP server: {e}")
|
||||||
|
print("\nTroubleshooting tips:")
|
||||||
|
print("1. Make sure the MCP server is installed and accessible")
|
||||||
|
print("2. Check if the server command is correct")
|
||||||
|
print("3. Ensure you have proper authentication/credentials configured")
|
||||||
|
print("4. Try running the MCP server command directly to test it")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def load_data(self, args) -> list[str]:
|
||||||
|
"""Load Slack messages via MCP server."""
|
||||||
|
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
||||||
|
|
||||||
|
if args.workspace_name:
|
||||||
|
print(f"Workspace: {args.workspace_name}")
|
||||||
|
|
||||||
|
# Filter out empty strings from channels
|
||||||
|
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
||||||
|
|
||||||
|
if channels:
|
||||||
|
print(f"Channels: {', '.join(channels)}")
|
||||||
|
else:
|
||||||
|
print("Fetching from all available channels")
|
||||||
|
|
||||||
|
concatenate = not args.no_concatenate_conversations
|
||||||
|
print(
|
||||||
|
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
reader = SlackMCPReader(
|
||||||
|
mcp_server_command=args.mcp_server,
|
||||||
|
workspace_name=args.workspace_name,
|
||||||
|
concatenate_conversations=concatenate,
|
||||||
|
max_messages_per_conversation=args.max_messages_per_channel,
|
||||||
|
max_retries=args.max_retries,
|
||||||
|
retry_delay=args.retry_delay,
|
||||||
|
)
|
||||||
|
|
||||||
|
texts = await reader.read_slack_data(channels=channels)
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
print("No messages found! This could mean:")
|
||||||
|
print("- The MCP server couldn't fetch messages")
|
||||||
|
print("- The specified channels don't exist or are empty")
|
||||||
|
print("- Authentication issues with the Slack workspace")
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
||||||
|
|
||||||
|
# Show sample of what was loaded
|
||||||
|
if texts:
|
||||||
|
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
|
||||||
|
print("\nSample content:")
|
||||||
|
print("-" * 40)
|
||||||
|
print(sample_text)
|
||||||
|
print("-" * 40)
|
||||||
|
|
||||||
|
return texts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading Slack data: {e}")
|
||||||
|
print("\nThis might be due to:")
|
||||||
|
print("- MCP server connection issues")
|
||||||
|
print("- Authentication problems")
|
||||||
|
print("- Network connectivity issues")
|
||||||
|
print("- Incorrect channel names")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Main entry point with MCP connection testing."""
|
||||||
|
args = self.parser.parse_args()
|
||||||
|
|
||||||
|
# Test connection if requested
|
||||||
|
if args.test_connection:
|
||||||
|
success = await self.test_mcp_connection(args)
|
||||||
|
if not success:
|
||||||
|
return
|
||||||
|
print(
|
||||||
|
"MCP server is working! You can now run without --test-connection to start indexing."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Run the standard RAG pipeline
|
||||||
|
await super().run()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main entry point for the Slack MCP RAG application."""
|
||||||
|
app = SlackMCPRAG()
|
||||||
|
await app.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
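For quick reference, here is a minimal sketch of driving the Slack reader behind this app directly, mirroring what `load_data()` does without the CLI wrapper. The server command, workspace, and channel names are illustrative placeholders, not values taken from this changeset.

```python
# Minimal sketch (assumptions noted inline): fetch Slack text chunks via SlackMCPReader.
import asyncio

from apps.slack_data.slack_mcp_reader import SlackMCPReader


async def preview_slack_chunks():
    reader = SlackMCPReader(
        mcp_server_command="npx slack-mcp-server",  # placeholder; use your actual MCP server command
        workspace_name="my-workspace",  # hypothetical workspace name
        concatenate_conversations=True,
        max_messages_per_conversation=100,
    )
    # Same call load_data() makes above; returns formatted text chunks ready for indexing.
    texts = await reader.read_slack_data(channels=["general"])
    print(f"Loaded {len(texts)} chunks")
    if texts:
        print(texts[0][:200])


if __name__ == "__main__":
    asyncio.run(preview_slack_chunks())
```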
1
apps/twitter_data/__init__.py
Normal file
@@ -0,0 +1 @@
# Twitter MCP data integration for LEANN
295
apps/twitter_data/twitter_mcp_reader.py
Normal file
@@ -0,0 +1,295 @@
#!/usr/bin/env python3
"""
Twitter MCP Reader for LEANN

This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
flexible bookmark processing options.
"""

import asyncio
import json
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


class TwitterMCPReader:
    """
    Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.

    This class connects to Twitter MCP servers to fetch bookmark data and convert it
    into a format suitable for LEANN indexing.
    """

    def __init__(
        self,
        mcp_server_command: str,
        username: Optional[str] = None,
        include_tweet_content: bool = True,
        include_metadata: bool = True,
        max_bookmarks: int = 1000,
    ):
        """
        Initialize the Twitter MCP Reader.

        Args:
            mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
            username: Optional Twitter username to filter bookmarks
            include_tweet_content: Whether to include full tweet content
            include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
            max_bookmarks: Maximum number of bookmarks to fetch
        """
        self.mcp_server_command = mcp_server_command
        self.username = username
        self.include_tweet_content = include_tweet_content
        self.include_metadata = include_metadata
        self.max_bookmarks = max_bookmarks
        self.mcp_process = None

    async def start_mcp_server(self):
        """Start the MCP server process."""
        try:
            self.mcp_process = await asyncio.create_subprocess_exec(
                *self.mcp_server_command.split(),
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            logger.info(f"Started MCP server: {self.mcp_server_command}")
        except Exception as e:
            logger.error(f"Failed to start MCP server: {e}")
            raise

    async def stop_mcp_server(self):
        """Stop the MCP server process."""
        if self.mcp_process:
            self.mcp_process.terminate()
            await self.mcp_process.wait()
            logger.info("Stopped MCP server")

    async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
        """Send a request to the MCP server and get response."""
        if not self.mcp_process:
            raise RuntimeError("MCP server not started")

        request_json = json.dumps(request) + "\n"
        self.mcp_process.stdin.write(request_json.encode())
        await self.mcp_process.stdin.drain()

        response_line = await self.mcp_process.stdout.readline()
        if not response_line:
            raise RuntimeError("No response from MCP server")

        return json.loads(response_line.decode().strip())

    async def initialize_mcp_connection(self):
        """Initialize the MCP connection."""
        init_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
            },
        }

        response = await self.send_mcp_request(init_request)
        if "error" in response:
            raise RuntimeError(f"MCP initialization failed: {response['error']}")

        logger.info("MCP connection initialized successfully")

    async def list_available_tools(self) -> list[dict[str, Any]]:
        """List available tools from the MCP server."""
        list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}

        response = await self.send_mcp_request(list_request)
        if "error" in response:
            raise RuntimeError(f"Failed to list tools: {response['error']}")

        return response.get("result", {}).get("tools", [])

    async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
        """
        Fetch Twitter bookmarks using MCP tools.

        Args:
            limit: Maximum number of bookmarks to fetch

        Returns:
            List of bookmark dictionaries
        """
        tools = await self.list_available_tools()
        bookmark_tool = None

        # Look for a tool that can fetch bookmarks
        for tool in tools:
            tool_name = tool.get("name", "").lower()
            if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
                bookmark_tool = tool
                break

        if not bookmark_tool:
            raise RuntimeError("No bookmark fetching tool found in MCP server")

        # Prepare tool call parameters
        tool_params = {}
        if limit or self.max_bookmarks:
            tool_params["limit"] = limit or self.max_bookmarks
        if self.username:
            tool_params["username"] = self.username

        fetch_request = {
            "jsonrpc": "2.0",
            "id": 3,
            "method": "tools/call",
            "params": {"name": bookmark_tool["name"], "arguments": tool_params},
        }

        response = await self.send_mcp_request(fetch_request)
        if "error" in response:
            raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")

        # Extract bookmarks from response
        result = response.get("result", {})
        if "content" in result and isinstance(result["content"], list):
            content = result["content"][0] if result["content"] else {}
            if "text" in content:
                try:
                    bookmarks = json.loads(content["text"])
                except json.JSONDecodeError:
                    # If not JSON, treat as plain text
                    bookmarks = [{"text": content["text"], "source": "twitter"}]
            else:
                bookmarks = result["content"]
        else:
            bookmarks = result.get("bookmarks", result.get("tweets", [result]))

        return bookmarks if isinstance(bookmarks, list) else [bookmarks]

    def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
        """Format a single bookmark for indexing."""
        # Extract tweet information
        text = bookmark.get("text", bookmark.get("content", ""))
        author = bookmark.get(
            "author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
        )
        timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
        url = bookmark.get("url", bookmark.get("tweet_url", ""))

        # Extract metadata if available
        likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
        retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
        replies = bookmark.get("replies", bookmark.get("reply_count", 0))

        # Build formatted bookmark
        parts = []

        # Header
        parts.append("=== Twitter Bookmark ===")

        if author:
            parts.append(f"Author: @{author}")

        if timestamp:
            # Format timestamp if it's a standard format
            try:
                import datetime

                if "T" in str(timestamp):  # ISO format
                    dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
                    formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
                else:
                    formatted_time = str(timestamp)
                parts.append(f"Date: {formatted_time}")
            except (ValueError, TypeError):
                parts.append(f"Date: {timestamp}")

        if url:
            parts.append(f"URL: {url}")

        # Tweet content
        if text and self.include_tweet_content:
            parts.append("")
            parts.append("Content:")
            parts.append(text)

        # Metadata
        if self.include_metadata and any([likes, retweets, replies]):
            parts.append("")
            parts.append("Engagement:")
            if likes:
                parts.append(f"  Likes: {likes}")
            if retweets:
                parts.append(f"  Retweets: {retweets}")
            if replies:
                parts.append(f"  Replies: {replies}")

        # Extract hashtags and mentions if available
        hashtags = bookmark.get("hashtags", [])
        mentions = bookmark.get("mentions", [])

        if hashtags or mentions:
            parts.append("")
            if hashtags:
                parts.append(f"Hashtags: {', '.join(hashtags)}")
            if mentions:
                parts.append(f"Mentions: {', '.join(mentions)}")

        return "\n".join(parts)

    async def read_twitter_bookmarks(self) -> list[str]:
        """
        Read Twitter bookmark data and return formatted text chunks.

        Returns:
            List of formatted text chunks ready for LEANN indexing
        """
        try:
            await self.start_mcp_server()
            await self.initialize_mcp_connection()

            print(f"Fetching up to {self.max_bookmarks} bookmarks...")
            if self.username:
                print(f"Filtering for user: @{self.username}")

            bookmarks = await self.fetch_twitter_bookmarks()

            if not bookmarks:
                print("No bookmarks found")
                return []

            print(f"Processing {len(bookmarks)} bookmarks...")

            all_texts = []
            processed_count = 0

            for bookmark in bookmarks:
                try:
                    formatted_bookmark = self._format_bookmark(bookmark)
                    if formatted_bookmark.strip():
                        all_texts.append(formatted_bookmark)
                        processed_count += 1
                except Exception as e:
                    logger.warning(f"Failed to format bookmark: {e}")
                    continue

            print(f"Successfully processed {processed_count} bookmarks")
            return all_texts

        finally:
            await self.stop_mcp_server()

    async def __aenter__(self):
        """Async context manager entry."""
        await self.start_mcp_server()
        await self.initialize_mcp_connection()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit."""
        await self.stop_mcp_server()
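The reader speaks plain JSON-RPC 2.0 over the server's stdin/stdout: an `initialize` handshake, a `tools/list` call to discover a bookmark-capable tool, then a `tools/call` to fetch the data. A minimal usage sketch follows; the server command is a placeholder and assumes whichever MCP server you point it at actually exposes a bookmark tool.

```python
# Minimal sketch: discover tools and pull bookmarks with TwitterMCPReader.
# "npx twitter-mcp-server" is a placeholder command, not a specific package from this repo.
import asyncio

from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader


async def preview_bookmarks():
    # The async context manager starts the server process and runs the initialize handshake.
    async with TwitterMCPReader(mcp_server_command="npx twitter-mcp-server") as reader:
        tools = await reader.list_available_tools()
        print("Tools:", [t.get("name") for t in tools])

    # read_twitter_bookmarks() manages its own server lifecycle, so call it on a fresh reader.
    reader = TwitterMCPReader(mcp_server_command="npx twitter-mcp-server", max_bookmarks=50)
    texts = await reader.read_twitter_bookmarks()
    print(f"{len(texts)} formatted bookmarks ready for indexing")


if __name__ == "__main__":
    asyncio.run(preview_bookmarks())
```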
195
apps/twitter_rag.py
Normal file
@@ -0,0 +1,195 @@
#!/usr/bin/env python3
"""
Twitter RAG Application with MCP Support

This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.

Usage:
    python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
"""

import argparse
import asyncio

from apps.base_rag_example import BaseRAGExample
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader


class TwitterMCPRAG(BaseRAGExample):
    """
    RAG application for Twitter bookmarks via MCP servers.

    This class provides a complete RAG pipeline for Twitter bookmark data, including
    MCP server connection, data fetching, indexing, and interactive chat.
    """

    def __init__(self):
        super().__init__(
            name="Twitter MCP RAG",
            description="RAG application for Twitter bookmarks via MCP servers",
            default_index_name="twitter_bookmarks",
        )

    def _add_specific_arguments(self, parser: argparse.ArgumentParser):
        """Add Twitter MCP-specific arguments."""
        parser.add_argument(
            "--mcp-server",
            type=str,
            required=True,
            help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
        )

        parser.add_argument(
            "--username", type=str, help="Twitter username to filter bookmarks (without @)"
        )

        parser.add_argument(
            "--max-bookmarks",
            type=int,
            default=1000,
            help="Maximum number of bookmarks to fetch (default: 1000)",
        )

        parser.add_argument(
            "--no-tweet-content",
            action="store_true",
            help="Exclude tweet content, only include metadata",
        )

        parser.add_argument(
            "--no-metadata",
            action="store_true",
            help="Exclude engagement metadata (likes, retweets, etc.)",
        )

        parser.add_argument(
            "--test-connection",
            action="store_true",
            help="Test MCP server connection and list available tools without indexing",
        )

    async def test_mcp_connection(self, args) -> bool:
        """Test the MCP server connection and display available tools."""
        print(f"Testing connection to MCP server: {args.mcp_server}")

        try:
            reader = TwitterMCPReader(
                mcp_server_command=args.mcp_server,
                username=args.username,
                include_tweet_content=not args.no_tweet_content,
                include_metadata=not args.no_metadata,
                max_bookmarks=args.max_bookmarks,
            )

            async with reader:
                tools = await reader.list_available_tools()

                print("\n✅ Successfully connected to MCP server!")
                print(f"Available tools ({len(tools)}):")

                for i, tool in enumerate(tools, 1):
                    name = tool.get("name", "Unknown")
                    description = tool.get("description", "No description available")
                    print(f"\n{i}. {name}")
                    print(
                        f"   Description: {description[:100]}{'...' if len(description) > 100 else ''}"
                    )

                    # Show input schema if available
                    schema = tool.get("inputSchema", {})
                    if schema.get("properties"):
                        props = list(schema["properties"].keys())[:3]  # Show first 3 properties
                        print(
                            f"   Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
                        )

            return True

        except Exception as e:
            print(f"\n❌ Failed to connect to MCP server: {e}")
            print("\nTroubleshooting tips:")
            print("1. Make sure the Twitter MCP server is installed and accessible")
            print("2. Check if the server command is correct")
            print("3. Ensure you have proper Twitter API credentials configured")
            print("4. Verify your Twitter account has bookmarks to fetch")
            print("5. Try running the MCP server command directly to test it")
            return False

    async def load_data(self, args) -> list[str]:
        """Load Twitter bookmarks via MCP server."""
        print(f"Connecting to Twitter MCP server: {args.mcp_server}")

        if args.username:
            print(f"Username filter: @{args.username}")

        print(f"Max bookmarks: {args.max_bookmarks}")
        print(f"Include tweet content: {not args.no_tweet_content}")
        print(f"Include metadata: {not args.no_metadata}")

        try:
            reader = TwitterMCPReader(
                mcp_server_command=args.mcp_server,
                username=args.username,
                include_tweet_content=not args.no_tweet_content,
                include_metadata=not args.no_metadata,
                max_bookmarks=args.max_bookmarks,
            )

            texts = await reader.read_twitter_bookmarks()

            if not texts:
                print("❌ No bookmarks found! This could mean:")
                print("- You don't have any bookmarks on Twitter")
                print("- The MCP server couldn't access your bookmarks")
                print("- Authentication issues with Twitter API")
                print("- The username filter didn't match any bookmarks")
                return []

            print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")

            # Show sample of what was loaded
            if texts:
                sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
                print("\nSample bookmark:")
                print("-" * 50)
                print(sample_text)
                print("-" * 50)

            return texts

        except Exception as e:
            print(f"❌ Error loading Twitter bookmarks: {e}")
            print("\nThis might be due to:")
            print("- MCP server connection issues")
            print("- Twitter API authentication problems")
            print("- Network connectivity issues")
            print("- Rate limiting from Twitter API")
            raise

    async def run(self):
        """Main entry point with MCP connection testing."""
        args = self.parser.parse_args()

        # Test connection if requested
        if args.test_connection:
            success = await self.test_mcp_connection(args)
            if not success:
                return
            print(
                "\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
            )
            return

        # Run the standard RAG pipeline
        await super().run()


async def main():
    """Main entry point for the Twitter MCP RAG application."""
    app = TwitterMCPRAG()
    await app.run()


if __name__ == "__main__":
    asyncio.run(main())
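To make the effect of the `--no-tweet-content` / `--no-metadata` flags concrete, here is a small sketch that formats one made-up bookmark with metadata disabled. The sample dictionary is invented, no MCP server is started, and it calls the reader's private `_format_bookmark` helper purely for illustration.

```python
# Sketch: how the include_* flags shape a formatted bookmark (sample data is fabricated).
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader

reader = TwitterMCPReader(
    mcp_server_command="twitter-mcp-server",  # placeholder; never launched in this sketch
    include_tweet_content=True,
    include_metadata=False,  # mirrors --no-metadata
)
sample = {
    "text": "LEANN keeps the index small by recomputing embeddings at search time.",
    "author": "example_user",
    "created_at": "2024-01-15T10:30:00Z",
    "likes": 42,
}
# The Engagement section is skipped because include_metadata is False.
print(reader._format_bookmark(sample))
```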
189
apps/wechat_rag.py
Normal file
@@ -0,0 +1,189 @@
"""
WeChat History RAG example using the unified interface.
Supports WeChat chat history export and search.
"""

import subprocess
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent))

from base_rag_example import BaseRAGExample

from .history_data.wechat_history import WeChatHistoryReader


class WeChatRAG(BaseRAGExample):
    """RAG example for WeChat chat history."""

    def __init__(self):
        # Set default values BEFORE calling super().__init__
        self.max_items_default = -1  # Match original default
        self.embedding_model_default = (
            "sentence-transformers/all-MiniLM-L6-v2"  # Fast 384-dim model
        )

        super().__init__(
            name="WeChat History",
            description="Process and query WeChat chat history with LEANN",
            default_index_name="wechat_history_magic_test_11Debug_new",
        )

    def _add_specific_arguments(self, parser):
        """Add WeChat-specific arguments."""
        wechat_group = parser.add_argument_group("WeChat Parameters")
        wechat_group.add_argument(
            "--export-dir",
            type=str,
            default="./wechat_export",
            help="Directory to store WeChat exports (default: ./wechat_export)",
        )
        wechat_group.add_argument(
            "--force-export",
            action="store_true",
            help="Force re-export of WeChat data even if exports exist",
        )
        wechat_group.add_argument(
            "--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
        )
        wechat_group.add_argument(
            "--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
        )

    def _export_wechat_data(self, export_dir: Path) -> bool:
        """Export WeChat data using wechattweak-cli."""
        print("Exporting WeChat data...")

        # Check if WeChat is running
        try:
            result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
            if result.returncode != 0:
                print("WeChat is not running. Please start WeChat first.")
                return False
        except Exception:
            pass  # pgrep might not be available on all systems

        # Create export directory
        export_dir.mkdir(parents=True, exist_ok=True)

        # Run export command
        cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]

        try:
            print(f"Running: {' '.join(cmd)}")
            result = subprocess.run(cmd, capture_output=True, text=True)

            if result.returncode == 0:
                print("WeChat data exported successfully!")
                return True
            else:
                print(f"Export failed: {result.stderr}")
                return False

        except FileNotFoundError:
            print("\nError: wechattweak-cli not found!")
            print("Please install it first:")
            print("  sudo packages/wechat-exporter/wechattweak-cli install")
            return False
        except Exception as e:
            print(f"Export error: {e}")
            return False

    async def load_data(self, args) -> list[str]:
        """Load WeChat history and convert to text chunks."""
        # Initialize WeChat reader with export capabilities
        reader = WeChatHistoryReader()

        # Find existing exports or create new ones using the centralized method
        export_dirs = reader.find_or_export_wechat_data(args.export_dir)
        if not export_dirs:
            print("Failed to find or export WeChat data. Trying to find any existing exports...")
            # Try to find any existing exports in common locations
            export_dirs = reader.find_wechat_export_dirs()
            if not export_dirs:
                print("No WeChat data found. Please ensure WeChat exports exist.")
                return []

        # Load documents from all found export directories
        all_documents = []
        total_processed = 0

        for i, export_dir in enumerate(export_dirs):
            print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")

            try:
                # Apply max_items limit per export
                max_per_export = -1
                if args.max_items > 0:
                    remaining = args.max_items - total_processed
                    if remaining <= 0:
                        break
                    max_per_export = remaining

                documents = reader.load_data(
                    wechat_export_dir=str(export_dir),
                    max_count=max_per_export,
                    concatenate_messages=True,  # Enable message concatenation for better context
                )

                if documents:
                    print(f"Loaded {len(documents)} chat documents from {export_dir}")
                    all_documents.extend(documents)
                    total_processed += len(documents)
                else:
                    print(f"No documents loaded from {export_dir}")

            except Exception as e:
                print(f"Error processing {export_dir}: {e}")
                continue

        if not all_documents:
            print("No documents loaded from any source. Exiting.")
            return []

        print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
        print("now starting to split into text chunks ... take some time")

        # Convert to text chunks with contact information
        all_texts = []
        for doc in all_documents:
            # Split the document into chunks
            from llama_index.core.node_parser import SentenceSplitter

            text_splitter = SentenceSplitter(
                chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
            )
            nodes = text_splitter.get_nodes_from_documents([doc])

            for node in nodes:
                # Add contact information to each chunk
                contact_name = doc.metadata.get("contact_name", "Unknown")
                text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
                all_texts.append(text)

        print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
        return all_texts


if __name__ == "__main__":
    import asyncio

    # Check platform
    if sys.platform != "darwin":
        print("\n⚠️  Warning: WeChat export is only supported on macOS")
        print("   You can still query existing exports on other platforms\n")

    # Example queries for WeChat RAG
    print("\n💬 WeChat History RAG Example")
    print("=" * 50)
    print("\nExample queries you can try:")
    print("- 'Show me conversations about travel plans'")
    print("- 'Find group chats about weekend activities'")
    print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
    print("- 'What did we discuss about the project last month?'")
    print("\nNote: WeChat must be running for export to work\n")

    rag = WeChatRAG()
    asyncio.run(rag.run())
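The chunking step above prefixes every chunk with the contact it came from. A small sketch of that transformation on an invented document, using the same default SentenceSplitter settings, looks like this:

```python
# Sketch of the per-chunk contact prefix applied in load_data(); the document text is invented.
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter

doc = Document(
    text="Let's buy the Magic Johnson jersey this weekend. I found a good deal online.",
    metadata={"contact_name": "Alice"},
)
splitter = SentenceSplitter(chunk_size=192, chunk_overlap=64)
for node in splitter.get_nodes_from_documents([doc]):
    contact_name = doc.metadata.get("contact_name", "Unknown")
    print(f"[Contact] means the message is from: {contact_name}\n" + node.get_content())
```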
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 73 KiB
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 224 KiB
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
After Width: | Height: | Size: 152 KiB
@@ -1,9 +1,24 @@
-# 🧪 Leann Sanity Checks
+# 🧪 LEANN Benchmarks & Testing
 
-This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
+This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
 
 ## 📁 Test Files
 
+### `diskann_vs_hnsw_speed_comparison.py`
+Performance comparison between DiskANN and HNSW backends:
+- ✅ **Search latency** comparison with both backends using recompute
+- ✅ **Index size** and **build time** measurements
+- ✅ **Score validity** testing (ensures no -inf scores)
+- ✅ **Configurable dataset sizes** for different scales
+
+```bash
+# Quick comparison with 500 docs, 10 queries
+python benchmarks/diskann_vs_hnsw_speed_comparison.py
+
+# Large-scale comparison with 2000 docs, 20 queries
+python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
+```
+
 ### `test_distance_functions.py`
 Tests all supported distance functions across DiskANN backend:
 - ✅ **MIPS** (Maximum Inner Product Search)
0
benchmarks/__init__.py
Normal file
@@ -1,29 +1,32 @@
 import time
-import numpy as np
 import matplotlib.pyplot as plt
-import torch
-from sentence_transformers import SentenceTransformer
 import mlx.core as mx
+import numpy as np
+import torch
 from mlx_lm import load
+from sentence_transformers import SentenceTransformer
 
 # --- Configuration ---
 MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
 MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
 BATCH_SIZES = [1, 8, 16, 32, 64, 128]
 NUM_RUNS = 10  # Number of runs to average for each batch size
 WARMUP_RUNS = 2  # Number of warm-up runs
 
 # --- Generate Dummy Data ---
 DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
 
 # --- Benchmark Functions ---b
 
 
 def benchmark_torch(model, sentences):
     start_time = time.time()
     model.encode(sentences, convert_to_numpy=True)
     end_time = time.time()
     return (end_time - start_time) * 1000  # Return time in ms
 
 
 def benchmark_mlx(model, tokenizer, sentences):
     start_time = time.time()
 
@@ -63,6 +66,7 @@ def benchmark_mlx(model, tokenizer, sentences):
     end_time = time.time()
     return (end_time - start_time) * 1000  # Return time in ms
 
 
 # --- Main Execution ---
 def main():
     print("--- Initializing Models ---")
@@ -98,7 +102,9 @@ def main():
         results_torch.append(np.mean(torch_times))
 
         # Benchmark MLX
-        mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
+        mlx_times = [
+            benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
+        ]
         results_mlx.append(np.mean(mlx_times))
 
     print("\n--- Benchmark Results (Average time per batch in ms) ---")
@@ -109,10 +115,16 @@ def main():
     # --- Plotting ---
     print("\n--- Generating Plot ---")
     plt.figure(figsize=(10, 6))
-    plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
-    plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
+    plt.plot(
+        BATCH_SIZES,
+        results_torch,
+        marker="o",
+        linestyle="-",
+        label=f"PyTorch ({device})",
+    )
+    plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
 
-    plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
+    plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
     plt.xlabel("Batch Size")
     plt.ylabel("Average Time per Batch (ms)")
     plt.xticks(BATCH_SIZES)
@@ -124,5 +136,6 @@ def main():
     plt.savefig(output_filename)
     print(f"Plot saved to {output_filename}")
 
 
 if __name__ == "__main__":
     main()
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
import argparse
import os
import time
from pathlib import Path

from leann import LeannBuilder, LeannSearcher


def _meta_exists(index_path: str) -> bool:
    p = Path(index_path)
    return (p.parent / f"{p.stem}.meta.json").exists()


def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
    # if _meta_exists(index_path):
    #     return
    kwargs = {}
    if backend_name == "hnsw":
        kwargs["is_compact"] = is_recompute
    builder = LeannBuilder(
        backend_name=backend_name,
        embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
        embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
        graph_degree=32,
        complexity=64,
        is_recompute=is_recompute,
        num_threads=4,
        **kwargs,
    )
    for i in range(num_docs):
        builder.add_text(
            f"This is a test document number {i}. It contains some repeated text for benchmarking."
        )
    builder.build_index(index_path)


def _bench_group(
    index_path: str,
    recompute: bool,
    query: str,
    repeats: int,
    complexity: int = 32,
    top_k: int = 10,
) -> float:
    # Independent searcher per group; fixed port when recompute
    searcher = LeannSearcher(index_path=index_path)

    # Warm-up once
    _ = searcher.search(
        query,
        top_k=top_k,
        complexity=complexity,
        recompute_embeddings=recompute,
    )

    def _once() -> float:
        t0 = time.time()
        _ = searcher.search(
            query,
            top_k=top_k,
            complexity=complexity,
            recompute_embeddings=recompute,
        )
        return time.time() - t0

    if repeats <= 1:
        t = _once()
    else:
        vals = [_once() for _ in range(repeats)]
        vals.sort()
        t = vals[len(vals) // 2]

    searcher.cleanup()
    return t


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-docs", type=int, default=5000)
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument("--complexity", type=int, default=32)
    args = parser.parse_args()

    base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
    base.parent.mkdir(parents=True, exist_ok=True)
    # ---------- Build HNSW variants ----------
    hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
    hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
    ensure_index(hnsw_r, "hnsw", args.num_docs, True)
    ensure_index(hnsw_nr, "hnsw", args.num_docs, False)

    # ---------- Build DiskANN variants ----------
    diskann_r = str(base / "diskann_r.leann")
    diskann_nr = str(base / "diskann_nr.leann")
    ensure_index(diskann_r, "diskann", args.num_docs, True)
    ensure_index(diskann_nr, "diskann", args.num_docs, False)

    # ---------- Helpers ----------
    def _size_for(prefix: str) -> int:
        p = Path(prefix)
        base_dir = p.parent
        stem = p.stem
        total = 0
        for f in base_dir.iterdir():
            if f.is_file() and f.name.startswith(stem):
                total += f.stat().st_size
        return total

    # ---------- HNSW benchmark ----------
    t_hnsw_r = _bench_group(
        hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
    )
    t_hnsw_nr = _bench_group(
        hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
    )
    size_hnsw_r = _size_for(hnsw_r)
    size_hnsw_nr = _size_for(hnsw_nr)

    print("Benchmark results (HNSW):")
    print(f"  recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
    print(
        f"  recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
    )
    print("  Expectation: no-recompute should be faster but larger on disk.")

    # ---------- DiskANN benchmark ----------
    t_diskann_r = _bench_group(
        diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
    )
    t_diskann_nr = _bench_group(
        diskann_nr,
        False,
        "DiskANN NR test doc 123",
        repeats=args.repeats,
        complexity=args.complexity,
    )
    size_diskann_r = _size_for(diskann_r)
    size_diskann_nr = _size_for(diskann_nr)

    print("\nBenchmark results (DiskANN):")
    print(f"  build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
    print(f"  build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
    print(f"  search recompute=True (final rerank): {t_diskann_r:.3f}s")
    print(f"  search recompute=False (PQ only): {t_diskann_nr:.3f}s")


if __name__ == "__main__":
    main()
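The measurement pattern in `_bench_group` is: warm up once, then report the median of a few repeated searches. A stripped-down sketch of that pattern against one of the indexes this script builds follows; the index path is an assumption and must already exist (e.g. from a prior run with `--num-docs 5000`).

```python
# Sketch of the warm-up + median-of-repeats timing used by _bench_group().
import time

from leann import LeannSearcher

# Assumed to have been built by a previous run of benchmark_no_recompute.py.
index_path = ".leann/indexes/bench_n5000/hnsw_norecompute_n5000.leann"
searcher = LeannSearcher(index_path=index_path)

kwargs = dict(top_k=10, complexity=32, recompute_embeddings=False)
searcher.search("test document number 42", **kwargs)  # warm-up, not timed

samples = []
for _ in range(3):
    t0 = time.time()
    searcher.search("test document number 42", **kwargs)
    samples.append(time.time() - t0)

print(f"median search latency: {sorted(samples)[1]:.3f}s")
searcher.cleanup()
```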
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
BM25 vs DiskANN Baselines

```bash
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
```

- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
- Machine-specific; results measured locally with the current repo.

DiskANN (NQ queries, search-only)
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)

BM25
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)

Notes
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.

After Width: | Height: | Size: 1.3 KiB
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
# /// script
# dependencies = [
#   "pyserini"
# ]
# ///
# sudo pacman -S jdk21-openjdk
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
# sudo archlinux-java status
# sudo archlinux-java set java-21-openjdk
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
# fish_add_path --global $JAVA_HOME/bin
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
# which javac  # Should be /usr/lib/jvm/java-21-openjdk/bin/javac

import argparse
import json
import os
import sys
import time
from statistics import mean


def load_queries(path: str, limit: int | None) -> list[str]:
    queries: list[str] = []
    # Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
    _, ext = os.path.splitext(path)
    if ext.lower() in {".jsonl", ".json"}:
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Not strict JSONL? treat the whole line as the query
                    queries.append(line)
                    continue
                q = obj.get("query") or obj.get("text") or obj.get("question")
                if q:
                    queries.append(str(q))
    else:
        with open(path, encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if s:
                    queries.append(s)

    if limit is not None and limit > 0:
        queries = queries[:limit]
    return queries


def percentile(values: list[float], p: float) -> float:
    if not values:
        return 0.0
    s = sorted(values)
    k = (len(s) - 1) * (p / 100.0)
    f = int(k)
    c = min(f + 1, len(s) - 1)
    if f == c:
        return s[f]
    return s[f] + (s[c] - s[f]) * (k - f)


def main():
    ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
    ap.add_argument(
        "--bm25-index",
        default="benchmarks/data/indices/bm25_index",
        help="Path to Pyserini Lucene index directory",
    )
    ap.add_argument(
        "--queries",
        default="benchmarks/data/queries/nq_open.jsonl",
        help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
    )
    ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
    ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
    ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
    ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
    ap.add_argument(
        "--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
    )
    ap.add_argument(
        "--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
    )
    ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
    args = ap.parse_args()

    try:
        from pyserini.search.lucene import LuceneSearcher
    except Exception:
        print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
        raise

    if not os.path.isdir(args.bm25_index):
        print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
        sys.exit(1)

    queries = load_queries(args.queries, args.limit)
    if not queries:
        print("No queries loaded.", file=sys.stderr)
        sys.exit(1)

    print(f"Loaded {len(queries)} queries from {args.queries}")
    print(f"Opening BM25 index: {args.bm25_index}")
    searcher = LuceneSearcher(args.bm25_index)
    # Some builds of pyserini require explicit set_bm25; others ignore
    try:
        searcher.set_bm25(k1=args.k1, b=args.b)
    except Exception:
        pass

    latencies: list[float] = []
    total_searches = 0

    # Warmup
    for i in range(min(args.warmup, len(queries))):
        _ = searcher.search(queries[i], k=args.k)

    t0 = time.time()
    for i, q in enumerate(queries):
        t1 = time.time()
        hits = searcher.search(q, k=args.k)
        t2 = time.time()
        latencies.append(t2 - t1)
        total_searches += 1

        if args.fetch_docs:
            # Optional doc fetch to include I/O time
            for h in hits:
                try:
                    _ = searcher.doc(h.docid)
                except Exception:
                    pass

        if (i + 1) % 50 == 0:
            print(f"Processed {i + 1}/{len(queries)} queries")

    t1 = time.time()
    total_time = t1 - t0

    if latencies:
        avg = mean(latencies)
        p50 = percentile(latencies, 50)
        p90 = percentile(latencies, 90)
        p95 = percentile(latencies, 95)
        p99 = percentile(latencies, 99)
        qps = total_searches / total_time if total_time > 0 else 0.0
    else:
        avg = p50 = p90 = p95 = p99 = qps = 0.0

    print("BM25 Latency Report")
    print(f"  queries: {total_searches}")
    print(f"  k: {args.k}, k1: {args.k1}, b: {args.b}")
    print(f"  avg per query: {avg:.6f} s")
    print(f"  p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
    print(f"  total time: {total_time:.3f} s, qps: {qps:.2f}")

    if args.report:
        payload = {
            "queries": total_searches,
            "k": args.k,
            "k1": args.k1,
            "b": args.b,
            "avg_s": avg,
            "p50_s": p50,
            "p90_s": p90,
            "p95_s": p95,
            "p99_s": p99,
            "total_time_s": total_time,
            "qps": qps,
            "index_dir": os.path.abspath(args.bm25_index),
            "fetch_docs": bool(args.fetch_docs),
        }
        with open(args.report, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        print(f"Saved report to {args.report}")


if __name__ == "__main__":
    main()
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
# /// script
# dependencies = [
#   "leann-backend-diskann"
# ]
# ///

import argparse
import json
import time
from pathlib import Path

import numpy as np


def load_queries(path: Path, limit: int | None) -> list[str]:
    out: list[str] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            obj = json.loads(line)
            out.append(obj["query"])
            if limit and len(out) >= limit:
                break
    return out


def main() -> None:
    ap = argparse.ArgumentParser(
        description="DiskANN baseline on real NQ queries (search-only timing)"
    )
    ap.add_argument(
        "--index-dir",
        default="benchmarks/data/indices/diskann_rpj_wiki",
        help="Directory containing DiskANN files",
    )
    ap.add_argument("--index-prefix", default="ann")
    ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
    ap.add_argument("--num-queries", type=int, default=200)
    ap.add_argument("--top-k", type=int, default=10)
    ap.add_argument("--complexity", type=int, default=62)
    ap.add_argument("--threads", type=int, default=1)
    ap.add_argument("--beam-width", type=int, default=1)
    ap.add_argument("--cache-mechanism", type=int, default=2)
    ap.add_argument("--num-nodes-to-cache", type=int, default=0)
    args = ap.parse_args()

    index_dir = Path(args.index_dir).resolve()
    if not index_dir.is_dir():
        raise SystemExit(f"Index dir not found: {index_dir}")

    qpath = Path(args.queries_file).resolve()
    if not qpath.exists():
        raise SystemExit(f"Queries file not found: {qpath}")

    queries = load_queries(qpath, args.num_queries)
    print(f"Loaded {len(queries)} queries from {qpath}")

    # Compute embeddings once (exclude from timing)
    from leann.api import compute_embeddings as _compute

    embs = _compute(
        queries,
        model_name="facebook/contriever-msmarco",
        mode="sentence-transformers",
        use_server=False,
    ).astype(np.float32)
    if embs.ndim != 2:
        raise SystemExit("Embedding compute failed or returned wrong shape")

    # Build searcher
    from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher

    index_prefix_path = str(index_dir / args.index_prefix)
    searcher = _DiskannSearcher(
        index_prefix_path,
        num_threads=int(args.threads),
        cache_mechanism=int(args.cache_mechanism),
        num_nodes_to_cache=int(args.num_nodes_to_cache),
    )

    # Warmup (not timed)
    _ = searcher.search(
        embs[0:1],
        top_k=args.top_k,
        complexity=args.complexity,
        beam_width=args.beam_width,
        prune_ratio=0.0,
        recompute_embeddings=False,
        batch_recompute=False,
        dedup_node_dis=False,
    )

    # Timed loop
    times: list[float] = []
    for i in range(embs.shape[0]):
        t0 = time.time()
        _ = searcher.search(
            embs[i : i + 1],
            top_k=args.top_k,
            complexity=args.complexity,
            beam_width=args.beam_width,
            prune_ratio=0.0,
            recompute_embeddings=False,
            batch_recompute=False,
            dedup_node_dis=False,
        )
        times.append(time.time() - t0)

    times_sorted = sorted(times)
    avg = float(sum(times) / len(times))
    p50 = times_sorted[len(times) // 2]
    p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]

    print("\nDiskANN (NQ, search-only) Report")
    print(f"  queries: {len(times)}")
    print(
        f"  k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
    )
    print(f"  avg per query: {avg:.6f} s")
    print(f"  p50/p95: {p50:.6f}/{p95:.6f} s")
    print(f"  QPS: {1.0 / avg:.2f}")


if __name__ == "__main__":
    main()
@@ -3,14 +3,15 @@
 Memory comparison between Faiss HNSW and LEANN HNSW backend
 """
 
+import gc
 import logging
 import os
+import subprocess
 import sys
 import time
-import psutil
-import gc
-import subprocess
 from pathlib import Path
 
+import psutil
 from llama_index.core.node_parser import SentenceSplitter
 
 # Setup logging
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
 
         try:
             result = subprocess.run(
-                [sys.executable, "examples/faiss_only.py"],
+                [sys.executable, "benchmarks/faiss_only.py"],
                 capture_output=True,
                 text=True,
                 timeout=300,
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
 
         for line in lines:
             if "Peak Memory:" in line:
-                peak_memory = float(
-                    line.split("Peak Memory:")[1].split("MB")[0].strip()
-                )
+                peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
 
         return {"peak_memory": peak_memory}
 
@@ -111,13 +110,12 @@ def test_leann_hnsw():
 
     tracker.checkpoint("After imports")
 
+    from leann.api import LeannBuilder
     from llama_index.core import SimpleDirectoryReader
-    from leann.api import LeannBuilder, LeannSearcher
-
 
     # Load and parse documents
     documents = SimpleDirectoryReader(
-        "examples/data",
+        "data",
        recursive=True,
        encoding="utf-8",
        required_exts=[".pdf", ".txt", ".md"],
@@ -135,6 +133,7 @@ def test_leann_hnsw():
         nodes = node_parser.get_nodes_from_documents([doc])
         for node in nodes:
             all_texts.append(node.get_content())
+    print(f"Total number of chunks: {len(all_texts)}")
 
     tracker.checkpoint("After text chunking")
 
@@ -201,11 +200,9 @@ def test_leann_hnsw():
     searcher = LeannSearcher(index_path)
     tracker.checkpoint("After searcher loading")
 
-
-
     print("Running search queries...")
     queries = [
         "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
         "What is LEANN and how does it work?",
         "华为诺亚方舟实验室的主要研究内容",
     ]
@@ -303,21 +300,15 @@ def main():
 
         print("\nLEANN vs Faiss Performance:")
         memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
-        print(
-            f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
-        )
+        print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
 
         # Storage comparison
         if leann_storage_size > faiss_storage_size:
             storage_ratio = leann_storage_size / faiss_storage_size
-            print(
-                f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
-            )
+            print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
         elif faiss_storage_size > leann_storage_size:
             storage_ratio = faiss_storage_size / leann_storage_size
-            print(
-                f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
-            )
+            print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
         else:
             print(" Storage Size: similar")
     else:
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
#!/usr/bin/env python3
"""
DiskANN vs HNSW Search Performance Comparison

This benchmark compares search performance between DiskANN and HNSW backends:
- DiskANN: With graph partitioning enabled (is_recompute=True)
- HNSW: With recompute enabled (is_recompute=True)
- Tests performance across different dataset sizes
- Measures search latency, recall, and index size
"""

import gc
import multiprocessing as mp
import tempfile
import time
from pathlib import Path
from typing import Any

import numpy as np

# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
try:
    mp.set_start_method("fork", force=True)
except Exception:
    pass


def create_test_texts(n_docs: int) -> list[str]:
    """Create synthetic test documents for benchmarking."""
    np.random.seed(42)
    topics = [
        "machine learning and artificial intelligence",
        "natural language processing and text analysis",
        "computer vision and image recognition",
        "data science and statistical analysis",
        "deep learning and neural networks",
        "information retrieval and search engines",
        "database systems and data management",
        "software engineering and programming",
        "cybersecurity and network protection",
        "cloud computing and distributed systems",
    ]

    texts = []
    for i in range(n_docs):
        topic = topics[i % len(topics)]
        variation = np.random.randint(1, 100)
        text = (
            f"This is document {i} about {topic}. Content variation {variation}. "
            f"Additional information about {topic} with details and examples. "
            f"Technical discussion of {topic} including implementation aspects."
        )
        texts.append(text)

    return texts


def benchmark_backend(
    backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
) -> dict[str, float]:
    """Benchmark a specific backend with the given configuration."""
    from leann.api import LeannBuilder, LeannSearcher

    print(f"\n🔧 Testing {backend_name.upper()} backend...")

    with tempfile.TemporaryDirectory() as temp_dir:
        index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")

        # Build index
        print(f"📦 Building {backend_name} index with {len(texts)} documents...")
        start_time = time.time()

        builder = LeannBuilder(
            backend_name=backend_name,
            embedding_model="facebook/contriever",
            embedding_mode="sentence-transformers",
            **backend_kwargs,
        )

        for text in texts:
            builder.add_text(text)

        builder.build_index(index_path)
        build_time = time.time() - start_time

        # Measure index size
        index_dir = Path(index_path).parent
        index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
        total_size = sum(f.stat().st_size for f in index_files if f.is_file())
        size_mb = total_size / (1024 * 1024)

        print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")

        # Search benchmark
        print("🔍 Running search benchmark...")
        searcher = LeannSearcher(index_path)

        search_times = []
        all_results = []

        for query in test_queries:
            start_time = time.time()
            results = searcher.search(query, top_k=5)
            search_time = time.time() - start_time
            search_times.append(search_time)
            all_results.append(results)

        avg_search_time = np.mean(search_times) * 1000  # Convert to ms
        print(f" ✅ Average search time: {avg_search_time:.1f}ms")

        # Check for valid scores (detect -inf issues)
        all_scores = [
            result.score
            for results in all_results
            for result in results
            if result.score is not None
        ]
        valid_scores = [
            score for score in all_scores if score != float("-inf") and score != float("inf")
        ]
        score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0

        # Clean up (ensure embedding server shutdown and object GC)
        try:
            if hasattr(searcher, "cleanup"):
                searcher.cleanup()
            del searcher
            del builder
            gc.collect()
        except Exception as e:
            print(f"⚠️ Warning: Resource cleanup error: {e}")

        return {
            "build_time": build_time,
            "avg_search_time_ms": avg_search_time,
            "index_size_mb": size_mb,
            "score_validity_rate": score_validity_rate,
        }


def run_comparison(n_docs: int = 500, n_queries: int = 10):
    """Run performance comparison between DiskANN and HNSW."""
    print("🚀 Starting DiskANN vs HNSW Performance Comparison")
    print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")

    # Create test data
    texts = create_test_texts(n_docs)
    test_queries = [
        "machine learning algorithms",
        "natural language processing",
        "computer vision techniques",
        "data analysis methods",
        "neural network architectures",
        "database query optimization",
        "software development practices",
        "security vulnerabilities",
        "cloud infrastructure",
        "distributed computing",
    ][:n_queries]

    # HNSW benchmark
    hnsw_results = benchmark_backend(
        backend_name="hnsw",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable recompute for fair comparison
            "M": 16,
            "efConstruction": 200,
        },
    )

    # DiskANN benchmark
    diskann_results = benchmark_backend(
        backend_name="diskann",
        texts=texts,
        test_queries=test_queries,
        backend_kwargs={
            "is_recompute": True,  # Enable graph partitioning
            "num_neighbors": 32,
            "search_list_size": 50,
        },
    )

    # Performance comparison
    print("\n📈 Performance Comparison Results")
    print(f"{'=' * 60}")
    print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
    print(f"{'-' * 60}")

    # Build time comparison
    build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
    print(
        f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
    )

    # Search time comparison
    search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
    print(
        f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
    )

    # Index size comparison
    size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
    print(
        f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
    )

    # Score validity
    print(
        f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
    )

    print(f"{'=' * 60}")
    print("\n🎯 Summary:")
    if search_speedup > 1:
        print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
    else:
        print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")

    if size_ratio > 1:
        print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
    else:
        print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")

    print(
        f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
    )


if __name__ == "__main__":
    import sys

    try:
        # Handle help request
        if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
            print("DiskANN vs HNSW Performance Comparison")
            print("=" * 50)
            print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
            print()
            print("Arguments:")
            print(" n_docs Number of documents to index (default: 500)")
            print(" n_queries Number of test queries to run (default: 10)")
            print()
            print("Examples:")
            print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
            print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
            print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
            sys.exit(0)

        # Parse command line arguments
        n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
        n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10

        print("DiskANN vs HNSW Performance Comparison")
        print("=" * 50)
        print(f"Dataset: {n_docs} documents, {n_queries} queries")
        print()

        run_comparison(n_docs=n_docs, n_queries=n_queries)

    except KeyboardInterrupt:
        print("\n⚠️ Benchmark interrupted by user")
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ Benchmark failed: {e}")
        sys.exit(1)
    finally:
        # Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
        try:
            gc.collect()
            print("\n🧹 Cleanup completed")
            # Flush stdio to ensure message is visible before hard-exit
            try:
                import sys as _sys

                _sys.stdout.flush()
                _sys.stderr.flush()
            except Exception:
                pass
        except Exception:
            pass
        # Use os._exit to bypass atexit handlers that may hang in rare cases
        import os as _os

        _os._exit(0)
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
# Enron Emails Benchmark

A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benchmarks, using stage-based evaluation with Recall@3 and generation timing.

- Dataset: Enron email CSV (e.g., the Kaggle `wcukierski/enron-email-dataset`) for passages
- Queries: `corbt/enron_emails_sample_questions` (filtered for realistic questions)
- Metrics: Recall@3 vs a FAISS Flat baseline, plus generation evaluation with Qwen3-8B

## Layout

benchmarks/enron_emails/
- setup_enron_emails.py: Prepare passages, build the LEANN index, build the FAISS baseline
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) and generation with Qwen3-8B
- data/: Generated passages, queries, and embedding-related files
- baseline/: FAISS Flat baseline files
- llm_utils.py: LLM utilities for Qwen3-8B generation (in the parent directory)

## Quickstart

1) Prepare the data and index

```bash
cd benchmarks/enron_emails
python setup_enron_emails.py --data-dir data
```

Notes:
- If `--emails-csv` is omitted, the script attempts to download the Kaggle dataset `wcukierski/enron-email-dataset` via the Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`). Alternatively, pass a local path to `--emails-csv`. A minimal sketch of this download path follows these notes.
- The script parses the emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
- Optionally, it also creates evaluation queries from the HuggingFace dataset `corbt/enron_emails_sample_questions`.
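The snippet below is a minimal sketch of the Kaggle download path that `setup_enron_emails.py` takes when no `--emails-csv` is given. It assumes the `kaggle` package is installed and credentials are configured; the local directory name is illustrative.

```python
# Minimal sketch of the automatic download in setup_enron_emails.py (assumes the
# `kaggle` package is installed and KAGGLE_USERNAME / KAGGLE_KEY are set).
from pathlib import Path

from kaggle.api.kaggle_api_extended import KaggleApi

downloads_dir = Path("data/downloads")
downloads_dir.mkdir(parents=True, exist_ok=True)

api = KaggleApi()
api.authenticate()
api.dataset_download_files(
    "wcukierski/enron-email-dataset", path=str(downloads_dir), unzip=True
)

emails_csv = downloads_dir / "emails.csv"
if not emails_csv.exists():
    raise FileNotFoundError(f"emails.csv was not found in {downloads_dir}")
```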
2) Run recall evaluation (Stage 2)

```bash
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
```

3) Complexity sweep (Stage 3)

```bash
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
```

Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assuming recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8; a minimal sketch of the search follows.
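The sketch below restates the Stage 3 search implemented in `evaluate_complexity` inside `evaluate_enron_emails.py`. Here `recall_at_3(c)` stands in for one full recall evaluation at complexity `c`; the function names and signature are illustrative.

```python
# Minimal sketch of the Stage 3 complexity search (mirrors evaluate_complexity in
# evaluate_enron_emails.py). recall_at_3(c) is a stand-in for one evaluation run.
def round_c(x: int) -> int:
    return max(1, (x + 7) // 8 * 8)  # snap complexity to a multiple of 8


def find_min_complexity(recall_at_3, target=0.90, c_min=8, c_max=256, max_iters=10, cap=1024):
    lo, hi = round_c(c_min), round_c(c_max)
    r_hi = recall_at_3(hi)
    # Expand the upper bound until the target is reachable (or the cap is hit).
    while r_hi < target and hi < cap:
        lo, hi = hi, round_c(hi * 2)
        r_hi = recall_at_3(hi)
    if r_hi < target:
        return None  # target recall not reachable within the cap

    best, iters = hi, 0
    while lo < hi and iters < max_iters:  # bounded, like the script's max_iters
        mid = round_c((lo + hi) // 2)
        if recall_at_3(mid) >= target:
            best, hi = mid, mid
        else:
            lo = mid + 8  # step past mid while staying on multiples of 8
        iters += 1
    return best
```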
4) Index comparison (Stage 4)

```bash
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
```

5) Generation evaluation (Stage 5)

```bash
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
```

6) Combined index + generation evaluation (Stages 4+5, recommended)

```bash
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
```

Notes:
- Minimal CLI: you can run from the repo root with only `--index`; defaults match the FinanceBench/LAION patterns:
  - `--stage` defaults to `all` (runs 2, 3, 4, 5)
  - `--baseline-dir` defaults to `baseline`
  - `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
  - `--llm-backend` defaults to `hf` (HuggingFace); `vllm` is also supported
  - `--model-name` defaults to `Qwen/Qwen3-8B`
- Fail-fast behavior: no silent fallbacks. If the compact index cannot run with recompute, it errors out.
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.

Optional flags:
- `--queries data/evaluation_queries.jsonl` (custom queries file)
- `--baseline-dir baseline` (where the FAISS baseline lives)
- `--complexity 88` (LEANN complexity parameter, optimal for 90% recall)
- `--llm-backend hf|vllm` (LLM backend for generation)
- `--model-name Qwen/Qwen3-8B` (LLM model for generation)
- `--max-queries 1000` (limit the number of queries for evaluation)

## Files Produced

- `data/enron_passages_preview.jsonl`: Small preview of the passages used (for inspection)
- `data/enron_index_hnsw.leann.*`: LEANN index files
- `baseline/faiss_flat.index` + `baseline/metadata.pkl`: FAISS baseline with passage IDs
- `data/evaluation_queries.jsonl`: Query file (id + query; includes GT IDs for reference)

## Notes

- Evaluates both retrieval Recall@3 and generation timing with the Qwen3-8B thinking model.
- The emails CSV must contain a column named "message" (the raw RFC 822 email) and a column named "file" as the source identifier. Message-ID headers are parsed as canonical message IDs when present; a minimal parsing sketch follows this list.
- Qwen3-8B is a thinking model and needs special handling: chat templates and `<think></think>` tag processing.
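For reference, the sketch below mirrors how `setup_enron_emails.py` extracts the Message-ID and splits a raw message into header and body (the script's version adds a heuristic fallback for messages without a blank line); it is illustrative rather than a drop-in import.

```python
# Minimal sketch of the Message-ID extraction and header/body split performed
# by setup_enron_emails.py (simplified; the script adds a heuristic fallback).
from email import message_from_string
from email.policy import default


def extract_message_id(raw_email: str) -> str:
    """Return the Message-ID header without surrounding angle brackets, or ""."""
    msg = message_from_string(raw_email, policy=default)
    val = msg.get("Message-ID", "")
    if val.startswith("<") and val.endswith(">"):
        val = val[1:-1]
    return val or ""


def split_header_body(raw_email: str) -> tuple[str, str]:
    """Split a raw RFC 822 message into (header, body) on the first blank line."""
    parts = raw_email.split("\n\n", 1)
    if len(parts) == 2:
        return parts[0].strip(), parts[1].strip()
    return raw_email.strip(), ""
```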
## Stages Summary

- Stage 2 (Recall@3):
  - Compares LEANN vs the FAISS Flat baseline on Recall@3 (see the sketch after this list).
  - The compact index runs with `recompute_embeddings=True`.

- Stage 3 (Binary Search for Complexity):
  - Builds a non-compact index (`<index>_noncompact.leann`) and runs a binary search with `recompute_embeddings=False` to find the minimal complexity achieving the target Recall@3 (default 90%).

- Stage 4 (Index Comparison):
  - Reports .index-only sizes for compact vs non-compact.
  - Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
  - Stores retrieval results for the Stage 5 generation evaluation.
  - Fails fast if compact recompute cannot run.
  - If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
    - first from the current run (when running `--stage all`), otherwise
    - from `enron_stage3_results.json`, saved next to the index during the last Stage 3 run;
    - if neither exists, Stage 4 errors and asks you to run Stage 3 or pass `--complexity`.

- Stage 5 (Generation Evaluation):
  - Uses the Qwen3-8B thinking model for RAG generation on the documents retrieved in Stage 4.
  - Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
  - Measures generation timing separately from search timing.
  - Requires Stage 4 results (no additional searching is performed).
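As a reference for how Recall@3 is computed in Stage 2, here is a minimal per-query sketch of the comparison in `evaluate_enron_emails.py`: the top-3 IDs from the FAISS Flat baseline serve as ground truth and are intersected with LEANN's top-3. The standard faiss `search` call is used here for brevity (the script itself goes through the `leann_backend_hnsw` bindings), and the function name is illustrative.

```python
# Minimal sketch of the per-query Recall@3 used in Stage 2. `searcher` is a
# LeannSearcher, `faiss_index` a FAISS Flat index over the same passages, and
# `passage_ids` maps FAISS row indices back to passage IDs.
import numpy as np

from leann.api import compute_embeddings


def recall_at_3(query, searcher, faiss_index, passage_ids, complexity=64):
    # Ground truth: exhaustive top-3 from the FAISS Flat baseline.
    q_emb = compute_embeddings(
        [query], searcher.embedding_model, mode=searcher.embedding_mode, use_server=False
    ).astype(np.float32)
    _, labels = faiss_index.search(q_emb, 3)
    baseline_ids = {passage_ids[idx] for idx in labels[0]}

    # Candidate set: LEANN's top-3 at the given complexity.
    results = searcher.search(query, top_k=3, complexity=complexity, recompute_embeddings=True)
    test_ids = {r.id for r in results}

    return len(test_ids & baseline_ids) / 3.0
```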
## Example Results

These are sample results obtained on the Enron data using all-mpnet-base-v2 and Qwen3-8B.

- Stage 3 (Binary Search):
  - Minimal complexity achieving 90% Recall@3: 88
  - Sampled points:
    - C=8 → 59.9% Recall@3
    - C=72 → 89.4% Recall@3
    - C=88 → 90.2% Recall@3
    - C=96 → 90.7% Recall@3
    - C=112 → 91.1% Recall@3
    - C=136 → 91.3% Recall@3
    - C=256 → 92.0% Recall@3

- Stage 4 (Index Sizes, .index only):
  - Compact: ~2.2 MB
  - Non-compact: ~82.0 MB
  - Storage saving by compact: ~97.3%

- Stage 4 (Search Timing, 988 queries, complexity=88):
  - Non-compact (no recompute): ~0.0075 s avg per query
  - Compact (with recompute): ~1.981 s avg per query
  - Speed ratio (non-compact/compact): ~0.0038x

- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
  - Average generation time: ~22.302 s per query
  - Total queries processed: 988
  - LLM backend: HuggingFace transformers
  - Model: Qwen/Qwen3-8B (thinking model with `<think></think>` processing)

Full JSON output is saved by the script (see `--output`), e.g.
`benchmarks/enron_emails/results_enron_stage45.json`.
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
downloads/
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||||
|
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||||
|
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||||
|
On errors, fail fast without fallbacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.passage_ids = pickle.load(f)
|
||||||
|
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||||
|
|
||||||
|
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
for i, query in enumerate(queries):
|
||||||
|
# Compute query embedding with the same model/mode as the index
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[query],
|
||||||
|
self.searcher.embedding_model,
|
||||||
|
mode=self.searcher.embedding_mode,
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS Flat ground truth
|
||||||
|
n = q_emb.shape[0]
|
||||||
|
k = 3
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(q_emb),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN (may require embedding server depending on index configuration)
|
||||||
|
results = self.searcher.search(
|
||||||
|
query,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {r.id for r in results}
|
||||||
|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN: {list(test_ids)}")
|
||||||
|
print(f" ∩: {list(intersection)}")
|
||||||
|
|
||||||
|
avg = total_recall / max(1, len(queries))
|
||||||
|
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||||
|
return avg
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class EnronEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
queries: list[str] = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
|
if "query" in data:
|
||||||
|
queries.append(data["query"])
|
||||||
|
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||||
|
|
||||||
|
sizes["index_only_mb"] = (
|
||||||
|
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source and embedding model
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
# Load all passages and ids
|
||||||
|
ids: list[str] = []
|
||||||
|
texts: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
texts.append(data["text"])
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
texts,
|
||||||
|
meta["embedding_model"],
|
||||||
|
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
use_server=False,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False,
|
||||||
|
is_compact=False,
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare search speed for non-compact vs compact indexes."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
results: dict = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||||
|
}
|
||||||
|
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
# Non-compact (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Compact (with recompute). Fail fast if it cannot run.
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
for q in test_queries:
|
||||||
|
t0 = time.time()
|
||||||
|
docs = compact_searcher.search(
|
||||||
|
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
results["compact"]["search_times"].append(time.time() - t0)
|
||||||
|
|
||||||
|
# Store retrieval results for Stage 5
|
||||||
|
results["retrieval_results"].append(
|
||||||
|
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||||
|
)
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
if results["non_compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
if results["compact"]["search_times"]:
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
if results["avg_search_times"].get("compact", 0) > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = 0.0
|
||||||
|
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_complexity(
|
||||||
|
self,
|
||||||
|
recall_eval: "RecallEvaluator",
|
||||||
|
queries: list[str],
|
||||||
|
target: float = 0.90,
|
||||||
|
c_min: int = 8,
|
||||||
|
c_max: int = 256,
|
||||||
|
max_iters: int = 10,
|
||||||
|
recompute: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||||
|
|
||||||
|
def round_c(x: int) -> int:
|
||||||
|
# snap to multiple of 8 like other benches typically do
|
||||||
|
return max(1, int((x + 7) // 8) * 8)
|
||||||
|
|
||||||
|
metrics: list[dict] = []
|
||||||
|
|
||||||
|
lo = round_c(c_min)
|
||||||
|
hi = round_c(c_max)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||||
|
r_lo = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=lo, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
cap = 1024
|
||||||
|
while r_hi < target and hi < cap:
|
||||||
|
lo = hi
|
||||||
|
r_lo = r_hi
|
||||||
|
hi = round_c(hi * 2)
|
||||||
|
r_hi = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=hi, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||||
|
|
||||||
|
if r_hi < target:
|
||||||
|
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||||
|
print("📈 Observations:")
|
||||||
|
for m in metrics:
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||||
|
|
||||||
|
# Binary search within [lo, hi]
|
||||||
|
best = hi
|
||||||
|
iters = 0
|
||||||
|
while lo < hi and iters < max_iters:
|
||||||
|
mid = round_c((lo + hi) // 2)
|
||||||
|
r_mid = recall_eval.evaluate_recall_at_3(
|
||||||
|
queries, complexity=mid, recompute_embeddings=recompute
|
||||||
|
)
|
||||||
|
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||||
|
if r_mid >= target:
|
||||||
|
best = mid
|
||||||
|
hi = mid
|
||||||
|
else:
|
||||||
|
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||||
|
iters += 1
|
||||||
|
|
||||||
|
print("📈 Binary search results (sampled points):")
|
||||||
|
# Print unique complexity entries ordered by complexity
|
||||||
|
for m in sorted(
|
||||||
|
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||||
|
):
|
||||||
|
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||||
|
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||||
|
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all", "45"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||||
|
)
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||||
|
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Resolve queries file: if default path not found, fall back to index's directory
|
||||||
|
if not os.path.exists(args.queries):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
idx_dir = Path(args.index).parent
|
||||||
|
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||||
|
if fallback_q.exists():
|
||||||
|
args.queries = str(fallback_q)
|
||||||
|
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
results_out: dict = {}
|
||||||
|
|
||||||
|
if args.stage in ("2", "all"):
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[:10]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||||
|
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("3", "all"):
|
||||||
|
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
queries = queries[: args.max_queries]
|
||||||
|
print(f"🧪 Using first {len(queries)} queries")
|
||||||
|
|
||||||
|
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact evaluator for binary search with recompute=False
|
||||||
|
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
sweep = enron_eval.evaluate_complexity(
|
||||||
|
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||||
|
)
|
||||||
|
results_out["stage3"] = sweep
|
||||||
|
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump({"stage3": sweep}, f, indent=2)
|
||||||
|
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||||
|
evaluator_nc.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 3 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("4", "all", "45"):
|
||||||
|
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
enron_eval = EnronEvaluator(args.index)
|
||||||
|
queries = enron_eval.load_queries(args.queries)
|
||||||
|
test_q = queries[: min(args.max_queries, len(queries))]
|
||||||
|
|
||||||
|
current_sizes = enron_eval.analyze_index_sizes()
|
||||||
|
# Build non-compact index for comparison (no fallback)
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
index_path = Path(args.index)
|
||||||
|
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||||
|
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||||
|
nc_eval = EnronEvaluator(non_compact_path)
|
||||||
|
|
||||||
|
if (
|
||||||
|
current_sizes.get("index_only_mb", 0) > 0
|
||||||
|
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||||
|
):
|
||||||
|
storage_saving_percent = max(
|
||||||
|
0.0,
|
||||||
|
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
storage_saving_percent = 0.0
|
||||||
|
|
||||||
|
if args.complexity is None:
|
||||||
|
# Prefer in-session Stage 3 result
|
||||||
|
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||||
|
complexity = results_out["stage3"]["best_complexity"]
|
||||||
|
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||||
|
else:
|
||||||
|
# Try to load last saved Stage 3 result near index
|
||||||
|
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||||
|
if default_stage3_path.exists():
|
||||||
|
with open(default_stage3_path, encoding="utf-8") as f:
|
||||||
|
prev = json.load(f)
|
||||||
|
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||||
|
if complexity is None:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||||
|
)
|
||||||
|
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||||
|
else:
|
||||||
|
raise SystemExit(
|
||||||
|
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
comp = enron_eval.compare_index_performance(
|
||||||
|
non_compact_path, args.index, test_q, complexity=complexity
|
||||||
|
)
|
||||||
|
results_out["stage4"] = {
|
||||||
|
"current_index": current_sizes,
|
||||||
|
"non_compact_index": non_compact_sizes,
|
||||||
|
"storage_saving_percent": storage_saving_percent,
|
||||||
|
"performance_comparison": comp,
|
||||||
|
}
|
||||||
|
nc_eval.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
enron_eval.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||||
|
|
||||||
|
# Check if Stage 4 results exist
|
||||||
|
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||||
|
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||||
|
print("💡 Run Stage 4 first or use --stage all")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||||
|
if not retrieval_results:
|
||||||
|
print("❌ No retrieval results found from Stage 4")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||||
|
|
||||||
|
# Load LLM
|
||||||
|
try:
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Run generation using stored retrieval results
|
||||||
|
import time
|
||||||
|
|
||||||
|
from llm_utils import create_prompt
|
||||||
|
|
||||||
|
generation_times = []
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
print("🤖 Running generation on pre-retrieved results...")
|
||||||
|
for i, item in enumerate(retrieval_results):
|
||||||
|
query = item["query"]
|
||||||
|
retrieved_docs = item["retrieved_docs"]
|
||||||
|
|
||||||
|
# Prepare context from retrieved docs
|
||||||
|
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||||
|
prompt = create_prompt(context, query, "emails")
|
||||||
|
|
||||||
|
# Time generation only
|
||||||
|
gen_start = time.time()
|
||||||
|
response = llm_func(prompt)
|
||||||
|
gen_time = time.time() - gen_start
|
||||||
|
|
||||||
|
generation_times.append(gen_time)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||||
|
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Queries: {len(retrieval_results)}")
|
||||||
|
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||||
|
print(" (Search time from Stage 4)")
|
||||||
|
|
||||||
|
results_out["stage5"] = {
|
||||||
|
"total_queries": len(retrieval_results),
|
||||||
|
"avg_generation_time": avg_gen_time,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"responses": responses,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Results:")
|
||||||
|
for i in range(min(3, len(retrieval_results))):
|
||||||
|
query = retrieval_results[i]["query"]
|
||||||
|
response = responses[i]
|
||||||
|
print(f" Q{i + 1}: {query[:60]}...")
|
||||||
|
print(f" A{i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||||
|
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.output and results_out:
|
||||||
|
with open(args.output, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(results_out, f, indent=2)
|
||||||
|
print(f"📝 Saved results to {args.output}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
"""
|
||||||
|
Enron Emails Benchmark Setup Script
|
||||||
|
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from email import message_from_string
|
||||||
|
from email.policy import default
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class EnronSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||||
|
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||||
|
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
self.downloads_dir = self.data_dir / "downloads"
|
||||||
|
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ----------------------------
|
||||||
|
# Dataset acquisition
|
||||||
|
# ----------------------------
|
||||||
|
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||||
|
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||||
|
if emails_csv:
|
||||||
|
p = Path(emails_csv)
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||||
|
|
||||||
|
api = KaggleApi()
|
||||||
|
api.authenticate()
|
||||||
|
api.dataset_download_files(
|
||||||
|
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||||
|
)
|
||||||
|
candidate = self.downloads_dir / "emails.csv"
|
||||||
|
if candidate.exists():
|
||||||
|
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||||
|
return str(candidate)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
    # ----------------------------
    # Data preparation
    # ----------------------------
    @staticmethod
    def _extract_message_id(raw_email: str) -> str:
        msg = message_from_string(raw_email, policy=default)
        val = msg.get("Message-ID", "")
        if val.startswith("<") and val.endswith(">"):
            val = val[1:-1]
        return val or ""

    @staticmethod
    def _split_header_body(raw_email: str) -> tuple[str, str]:
        parts = raw_email.split("\n\n", 1)
        if len(parts) == 2:
            return parts[0].strip(), parts[1].strip()
        # Heuristic fallback
        first_lines = raw_email.splitlines()
        if first_lines and ":" in first_lines[0]:
            return raw_email.strip(), ""
        return "", raw_email.strip()

    @staticmethod
    def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
        text = (text or "").strip()
        if not text:
            return []
        if chunk_words <= 0:
            return [text]
        words = text.split()
        if not words:
            return []
        limit = len(words)
        if not keep_last:
            limit = (len(words) // chunk_words) * chunk_words
            if limit == 0:
                return []
        chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
        return [c for c in (s.strip() for s in chunks) if c]

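    # Illustrative example (not part of the original script): with chunk_words=3,
    # "a b c d e f g" yields ["a b c", "d e f", "g"] when keep_last=True; with
    # keep_last=False the trailing partial chunk "g" is dropped and only full
    # 3-word chunks are returned.
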
    def _iter_passages_from_csv(
        self,
        emails_csv: Path,
        chunk_words: int = 256,
        keep_last_header: bool = True,
        keep_last_body: bool = True,
        max_emails: int | None = None,
    ) -> Iterable[dict]:
        with open(emails_csv, encoding="utf-8") as f:
            reader = csv.DictReader(f)
            count = 0
            for i, row in enumerate(reader):
                if max_emails is not None and count >= max_emails:
                    break

                raw_message = row.get("message", "")
                email_file_id = row.get("file", "")

                if not raw_message.strip():
                    continue

                message_id = self._extract_message_id(raw_message)
                if not message_id:
                    # Fallback ID based on CSV position and file path
                    safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
                    message_id = f"enron_{i}_{safe_file}"

                header, body = self._split_header_body(raw_message)

                # Header chunks
                for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
                    yield {
                        "text": chunk,
                        "metadata": {
                            "message_id": message_id,
                            "is_header": True,
                            "email_file_id": email_file_id,
                        },
                    }

                # Body chunks
                for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
                    yield {
                        "text": chunk,
                        "metadata": {
                            "message_id": message_id,
                            "is_header": False,
                            "email_file_id": email_file_id,
                        },
                    }

                count += 1

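    # Each yielded passage is a dict of the form
    #   {"text": "<chunk>", "metadata": {"message_id": ..., "is_header": bool, "email_file_id": ...}},
    # so header and body chunks from the same email share one message_id.
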
    # ----------------------------
    # Build LEANN index and FAISS baseline
    # ----------------------------
    def build_leann_index(
        self,
        emails_csv: Optional[str],
        backend: str = "hnsw",
        embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
        chunk_words: int = 256,
        max_emails: int | None = None,
    ) -> str:
        emails_csv_path = self.ensure_emails_csv(emails_csv)
        print(f"🏗️ Building LEANN index from {emails_csv_path}...")

        builder = LeannBuilder(
            backend_name=backend,
            embedding_model=embedding_model,
            embedding_mode="sentence-transformers",
            graph_degree=32,
            complexity=64,
            is_recompute=True,
            is_compact=True,
            num_threads=4,
        )

        # Stream passages and add to builder
        preview_written = 0
        with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
            for p in self._iter_passages_from_csv(
                Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
            ):
                builder.add_text(p["text"], metadata=p["metadata"])
                if preview_written < 200:
                    preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
                    preview_written += 1

        print(f"🔨 Building index at {self.index_path}...")
        builder.build_index(str(self.index_path))
        print("✅ LEANN index built!")
        return str(self.index_path)

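    # Usage sketch (hypothetical paths, for illustration only; mirrors main() below):
    #   setup = EnronSetup("data")
    #   index_path = setup.build_leann_index(emails_csv="data/downloads/emails.csv",
    #                                        chunk_words=256, max_emails=1000)
    #   setup.build_faiss_flat_baseline(index_path)
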
    def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
        print("🔨 Building FAISS Flat baseline from LEANN passages...")

        import pickle

        import numpy as np
        from leann.api import compute_embeddings
        from leann_backend_hnsw import faiss

        os.makedirs(output_dir, exist_ok=True)
        baseline_path = os.path.join(output_dir, "faiss_flat.index")
        metadata_path = os.path.join(output_dir, "metadata.pkl")

        if os.path.exists(baseline_path) and os.path.exists(metadata_path):
            print(f"✅ Baseline already exists at {baseline_path}")
            return baseline_path

        # Read meta for passage source and embedding model
        meta_path = f"{index_path}.meta.json"
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        embedding_model = meta["embedding_model"]
        passage_source = meta["passage_sources"][0]
        passage_file = passage_source["path"]

        if not os.path.isabs(passage_file):
            index_dir = os.path.dirname(index_path)
            passage_file = os.path.join(index_dir, os.path.basename(passage_file))

        # Load passages from builder output so IDs match LEANN
        passages: list[str] = []
        passage_ids: list[str] = []
        with open(passage_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                data = json.loads(line)
                passages.append(data["text"])
                passage_ids.append(data["id"])  # builder-assigned ID

        print(f"📄 Loaded {len(passages)} passages for baseline")
        print(f"🤖 Embedding model: {embedding_model}")

        embeddings = compute_embeddings(
            passages,
            embedding_model,
            mode="sentence-transformers",
            use_server=False,
        )

        # Build FAISS IndexFlatIP
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        emb_f32 = embeddings.astype(np.float32)
        index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))

        faiss.write_index(index, baseline_path)
        with open(metadata_path, "wb") as pf:
            pickle.dump(passage_ids, pf)

        print(f"✅ FAISS baseline saved: {baseline_path}")
        print(f"✅ Metadata saved: {metadata_path}")
        print(f"📊 Total vectors: {index.ntotal}")
        return baseline_path

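    # Note: IndexFlatIP performs exact (brute-force) inner-product search, so the
    # top-k results of this baseline serve as the ground truth when the evaluation
    # script measures LEANN's Recall@k.
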
    # ----------------------------
    # Queries (optional): prepare evaluation queries file
    # ----------------------------
    def prepare_queries(self, min_realism: float = 0.85) -> Path:
        print(
            "📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
        )
        try:
            from datasets import load_dataset

            ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
        except Exception as e:
            print(f"⚠️ Failed to load dataset: {e}")
            return self.queries_file

        kept = 0
        with open(self.queries_file, "w", encoding="utf-8") as out:
            for i, item in enumerate(ds):
                how_realistic = float(item.get("how_realistic", 0.0))
                if how_realistic < min_realism:
                    continue
                qid = str(item.get("id", f"enron_q_{i}"))
                query = item.get("question", "")
                if not query:
                    continue
                record = {
                    "id": qid,
                    "query": query,
                    # For reference only, not used in recall metric below
                    "gt_message_ids": item.get("message_ids", []),
                }
                out.write(json.dumps(record) + "\n")
                kept += 1
        print(f"✅ Wrote {kept} queries to {self.queries_file}")
        return self.queries_file


def main():
    parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
    parser.add_argument(
        "--emails-csv",
        help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
    )
    parser.add_argument("--data-dir", default="data", help="Data directory")
    parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
    parser.add_argument(
        "--embedding-model",
        default="sentence-transformers/all-mpnet-base-v2",
        help="Embedding model for LEANN",
    )
    parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
    parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
    parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
    parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")

    args = parser.parse_args()

    setup = EnronSetup(args.data_dir)

    # Build index
    if not args.skip_build:
        index_path = setup.build_leann_index(
            emails_csv=args.emails_csv,
            backend=args.backend,
            embedding_model=args.embedding_model,
            chunk_words=args.chunk_words,
            max_emails=args.max_emails,
        )

        # Build FAISS baseline from the same passages & embeddings
        setup.build_faiss_flat_baseline(index_path)
    else:
        print("⏭️ Skipping LEANN index build and baseline")

    # Queries file (optional)
    if not args.skip_queries:
        setup.prepare_queries()
    else:
        print("⏭️ Skipping query preparation")

    print("\n🎉 Enron Emails setup completed!")
    print(f"📁 Data directory: {setup.data_dir.absolute()}")
    print("Next steps:")
    print(
        "1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
    )


if __name__ == "__main__":
    main()

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 """Test only Faiss HNSW"""

+import os
 import sys
 import time

 import psutil
-import gc
-import os


 def get_memory_usage():
@@ -37,20 +37,20 @@ def main():
         import faiss
     except ImportError:
         print("Faiss is not installed.")
-        print("Please install it with `uv pip install faiss-cpu`")
+        print(
+            "Please install it with `uv pip install faiss-cpu` and you can then run this script again"
+        )
         sys.exit(1)

     from llama_index.core import (
-        SimpleDirectoryReader,
-        VectorStoreIndex,
-        StorageContext,
         Settings,
-        node_parser,
-        Document,
+        SimpleDirectoryReader,
+        StorageContext,
+        VectorStoreIndex,
     )
     from llama_index.core.node_parser import SentenceSplitter
-    from llama_index.vector_stores.faiss import FaissVectorStore
     from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+    from llama_index.vector_stores.faiss import FaissVectorStore

     tracker = MemoryTracker("Faiss HNSW")
     tracker.checkpoint("Initial")
@@ -65,7 +65,7 @@ def main():
     tracker.checkpoint("After Faiss index creation")

     documents = SimpleDirectoryReader(
-        "examples/data",
+        "data",
         recursive=True,
         encoding="utf-8",
         required_exts=[".pdf", ".txt", ".md"],
@@ -90,8 +90,9 @@ def main():
             vector_store=vector_store, persist_dir="./storage_faiss"
         )
         from llama_index.core import load_index_from_storage

         index = load_index_from_storage(storage_context=storage_context)
-        print(f"Index loaded from ./storage_faiss")
+        print("Index loaded from ./storage_faiss")
         tracker.checkpoint("After loading existing index")
         index_loaded = True
     except Exception as e:
@@ -99,6 +100,7 @@ def main():
             print("Cleaning up corrupted index and building new one...")
             # Clean up corrupted index
             import shutil

             if os.path.exists("./storage_faiss"):
                 shutil.rmtree("./storage_faiss")

@@ -109,9 +111,7 @@ def main():
         vector_store = FaissVectorStore(faiss_index=faiss_index)
         storage_context = StorageContext.from_defaults(vector_store=vector_store)
         index = VectorStoreIndex.from_documents(
-            documents,
-            storage_context=storage_context,
-            transformations=[node_parser]
+            documents, storage_context=storage_context, transformations=[node_parser]
         )
         tracker.checkpoint("After index building")

@@ -127,7 +127,7 @@ def main():

     query_engine = index.as_query_engine(similarity_top_k=20)
     queries = [
         "什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
         "What is LEANN and how does it work?",
         "华为诺亚方舟实验室的主要研究内容",
     ]
115  benchmarks/financebench/README.md  Normal file
@@ -0,0 +1,115 @@
# FinanceBench Benchmark for LEANN-RAG

FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.

## Dataset

- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
- **Questions**: 150 financial Q&A examples
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)

## Structure

```
benchmarks/financebench/
├── setup_financebench.py       # Downloads PDFs and builds index
├── evaluate_financebench.py    # Intelligent evaluation script
├── data/
│   ├── financebench_merged.jsonl   # Q&A dataset
│   ├── pdfs/                       # Downloaded financial documents
│   └── index/                      # LEANN indexes
│       └── financebench_full_hnsw.leann
└── README.md
```

## Usage

### 1. Setup (Download & Build Index)

```bash
cd benchmarks/financebench
python setup_financebench.py
```

This will:
- Download the 150 Q&A examples
- Download all 368 PDF documents (parallel processing)
- Build a LEANN index from 53K+ text chunks
- Verify setup with test query

### 2. Evaluation

```bash
# Basic retrieval evaluation
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann

# RAG generation evaluation with Qwen3-8B
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
```

## Evaluation Methods

### Retrieval Evaluation
Uses intelligent matching with three strategies:
1. **Exact text overlap** - Direct substring matches
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
3. **Semantic similarity** - Word overlap with 20% threshold

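The number-matching strategy above can be illustrated with a minimal sketch. This is not the script's exact implementation; the regex and normalization below are illustrative assumptions:

```python
import re

def extract_numbers(text: str) -> set[str]:
    """Pull out dollar amounts and bare figures such as $1,577 or 1.2B."""
    pattern = r"\$?\d[\d,]*(?:\.\d+)?\s*(?:[BMK]|billion|million)?"
    return {m.replace(",", "").replace("$", "").strip() for m in re.findall(pattern, text)}

def number_match(evidence: str, retrieved: str) -> bool:
    """A retrieved chunk counts as a match if it shares any key figure with the evidence."""
    return bool(extract_numbers(evidence) & extract_numbers(retrieved))
```

Because the same figure often appears in several retrieved chunks, counting every match can push the rate above 100%, which is why the results below report a number match rate of 120.7%.
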
### QA Evaluation
LLM-based answer evaluation using GPT-4o:
- Handles numerical rounding and equivalent representations
- Considers fractions, percentages, and decimal equivalents
- Evaluates semantic meaning rather than exact text match

## Benchmark Results

### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)

**Retrieval Metrics:**
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
- **Number Match Rate**: 120.7% (key financial figures matched)*
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
- **Average Search Time**: 0.097s

**QA Metrics:**
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
- **Average QA Time**: 4.71s (end-to-end response time)

**System Performance:**
- **Index Size**: 53,985 chunks from 368 PDFs
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2

*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.

### LEANN-RAG Generation Performance (Qwen3-8B)

- **Stage 4 (Index Comparison):**
  - Compact Index: 5.0 MB
  - Non-compact Index: 172.2 MB
  - **Storage Saving**: 97.1%
- **Search Performance**:
  - Non-compact (no recompute): 0.009s avg per query
  - Compact (with recompute): 2.203s avg per query
  - Speed ratio: 0.004x

**Generation Evaluation (20 queries, complexity=64):**
- **Average Search Time**: 1.638s per query
- **Average Generation Time**: 45.957s per query
- **LLM Backend**: HuggingFace transformers
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
- **Total Questions Processed**: 20

## Options

```bash
# Use different backends
python setup_financebench.py --backend diskann
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann

# Use different embedding models
python setup_financebench.py --embedding-model facebook/contriever
```
923  benchmarks/financebench/evaluate_financebench.py  Executable file
@@ -0,0 +1,923 @@
"""
|
||||||
|
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import openai
|
||||||
|
from leann import LeannChat, LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
    """Stage 2: Evaluate Recall@3 (searcher vs baseline)"""

    def __init__(self, index_path: str, baseline_dir: str):
        self.index_path = index_path
        self.baseline_dir = baseline_dir
        self.searcher = LeannSearcher(index_path)

        # Load FAISS flat baseline
        baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
        metadata_path = os.path.join(baseline_dir, "metadata.pkl")

        self.faiss_index = faiss.read_index(baseline_index_path)
        with open(metadata_path, "rb") as f:
            self.passage_ids = pickle.load(f)
        print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")

    def evaluate_recall_at_3(
        self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
    ) -> float:
        """Evaluate recall@3 for given queries at specified complexity"""
        recompute_str = "with recompute" if recompute_embeddings else "no recompute"
        print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")

        total_recall = 0.0
        num_queries = len(queries)

        for i, query in enumerate(queries):
            # Get ground truth: search with FAISS flat
            from leann.api import compute_embeddings

            query_embedding = compute_embeddings(
                [query],
                self.searcher.embedding_model,
                mode=self.searcher.embedding_mode,
                use_server=False,
            ).astype(np.float32)

            # Search FAISS flat for ground truth using LEANN's modified faiss API
            n = query_embedding.shape[0]  # Number of queries
            k = 3  # Number of nearest neighbors
            distances = np.zeros((n, k), dtype=np.float32)
            labels = np.zeros((n, k), dtype=np.int64)

            self.faiss_index.search(
                n,
                faiss.swig_ptr(query_embedding),
                k,
                faiss.swig_ptr(distances),
                faiss.swig_ptr(labels),
            )

            # Extract the results
            baseline_ids = {self.passage_ids[idx] for idx in labels[0]}

            # Search with LEANN at specified complexity
            test_results = self.searcher.search(
                query,
                top_k=3,
                complexity=complexity,
                recompute_embeddings=recompute_embeddings,
            )
            test_ids = {result.id for result in test_results}

            # Calculate recall@3 = |intersection| / |ground_truth|
            intersection = test_ids.intersection(baseline_ids)
            recall = len(intersection) / 3.0  # Ground truth size is 3
            total_recall += recall

            if i < 3:  # Show first few examples
                print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
                print(f" FAISS ground truth: {list(baseline_ids)}")
                print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
                print(f" Intersection: {list(intersection)}")

        avg_recall = total_recall / num_queries
        print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
        return avg_recall
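
    # Worked example: if FAISS flat returns ground-truth IDs {p12, p40, p77} and
    # LEANN returns {p12, p77, p90}, the intersection has 2 elements, so
    # recall@3 = 2 / 3 ≈ 0.667 for that query; the method reports the mean over all queries.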

    def cleanup(self):
        """Cleanup resources"""
        if hasattr(self, "searcher"):
            self.searcher.cleanup()


class FinanceBenchEvaluator:
|
||||||
|
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||||
|
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||||
|
|
||||||
|
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||||
|
"""Load FinanceBench dataset"""
|
||||||
|
data = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data.append(json.loads(line))
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||||
|
return data
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes with and without embeddings"""
|
||||||
|
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes = {}
|
||||||
|
total_with_embeddings = 0
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
for file_path, name in [
|
||||||
|
(index_file, "index"),
|
||||||
|
(meta_file, "metadata"),
|
||||||
|
(passages_file, "passages_text"),
|
||||||
|
(passages_idx_file, "passages_index"),
|
||||||
|
]:
|
||||||
|
if file_path.exists():
|
||||||
|
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||||
|
sizes[name] = size_mb
|
||||||
|
total_with_embeddings += size_mb
|
||||||
|
|
||||||
|
else:
|
||||||
|
sizes[name] = 0
|
||||||
|
|
||||||
|
sizes["total_with_embeddings"] = total_with_embeddings
|
||||||
|
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||||
|
|
||||||
|
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||||
|
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||||
|
"""Create a compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage
|
||||||
|
**meta.get("backend_kwargs", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||||
|
builder.build_index(compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||||
|
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
compact_sizes["index_type"] = "compact"
|
||||||
|
|
||||||
|
return compact_sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=meta["embedding_model"],
|
||||||
|
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all passages
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||||
|
builder.build_index(non_compact_index_path)
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from leann import LeannSearcher
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = [item["question"] for item in test_data[:5]]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def evaluate_timing_breakdown(
|
||||||
|
self, data: list[dict], max_samples: Optional[int] = None
|
||||||
|
) -> dict:
|
||||||
|
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||||
|
if not self.chat or not self.openai_client:
|
||||||
|
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||||
|
return {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0.0,
|
||||||
|
"avg_generation_time": 0.0,
|
||||||
|
"avg_total_time": 0.0,
|
||||||
|
"accuracy": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||||
|
|
||||||
|
if max_samples:
|
||||||
|
data = data[:max_samples]
|
||||||
|
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||||
|
|
||||||
|
search_times = []
|
||||||
|
generation_times = []
|
||||||
|
total_times = []
|
||||||
|
correct_answers = 0
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
question = item["question"]
|
||||||
|
ground_truth = item["answer"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Hack: Monkey-patch the ask method to capture internal timing
|
||||||
|
original_ask = self.chat.ask
|
||||||
|
captured_search_time = None
|
||||||
|
captured_generation_time = None
|
||||||
|
|
||||||
|
def patched_ask(*args, **kwargs):
|
||||||
|
nonlocal captured_search_time, captured_generation_time
|
||||||
|
|
||||||
|
# Time the search part
|
||||||
|
search_start = time.time()
|
||||||
|
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||||
|
captured_search_time = time.time() - search_start
|
||||||
|
|
||||||
|
# Time the generation part
|
||||||
|
context = "\n\n".join([r.text for r in results])
|
||||||
|
prompt = (
|
||||||
|
"Here is some retrieved context that might help answer your question:\n\n"
|
||||||
|
f"{context}\n\n"
|
||||||
|
f"Question: {args[0]}\n\n"
|
||||||
|
"Please provide the best answer you can based on this context and your knowledge."
|
||||||
|
)
|
||||||
|
|
||||||
|
generation_start = time.time()
|
||||||
|
answer = self.chat.llm.ask(prompt)
|
||||||
|
captured_generation_time = time.time() - generation_start
|
||||||
|
|
||||||
|
return answer
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
self.chat.ask = patched_ask
|
||||||
|
|
||||||
|
# Time the total QA
|
||||||
|
total_start = time.time()
|
||||||
|
generated_answer = self.chat.ask(question)
|
||||||
|
total_time = time.time() - total_start
|
||||||
|
|
||||||
|
# Restore original method
|
||||||
|
self.chat.ask = original_ask
|
||||||
|
|
||||||
|
# Store the timings
|
||||||
|
search_times.append(captured_search_time)
|
||||||
|
generation_times.append(captured_generation_time)
|
||||||
|
total_times.append(total_time)
|
||||||
|
|
||||||
|
# Check accuracy using LLM as judge
|
||||||
|
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||||
|
if is_correct:
|
||||||
|
correct_answers += 1
|
||||||
|
|
||||||
|
status = "✅" if is_correct else "❌"
|
||||||
|
print(
|
||||||
|
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||||
|
)
|
||||||
|
print(f" GT: {ground_truth}")
|
||||||
|
print(f" Gen: {generated_answer[:100]}...")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Error: {e}")
|
||||||
|
search_times.append(0.0)
|
||||||
|
generation_times.append(0.0)
|
||||||
|
total_times.append(0.0)
|
||||||
|
|
||||||
|
accuracy = correct_answers / len(data) if data else 0.0
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"total_questions": len(data),
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||||
|
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||||
|
if generation_times
|
||||||
|
else 0.0,
|
||||||
|
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"correct_answers": correct_answers,
|
||||||
|
"search_times": search_times,
|
||||||
|
"generation_times": generation_times,
|
||||||
|
"total_times": total_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _check_answer_accuracy(
|
||||||
|
self, generated_answer: str, ground_truth: str, question: str
|
||||||
|
) -> bool:
|
||||||
|
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||||
|
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||||
|
|
||||||
|
Question: {question}
|
||||||
|
|
||||||
|
Ground Truth Answer: {ground_truth}
|
||||||
|
|
||||||
|
Generated Answer: {generated_answer}
|
||||||
|
|
||||||
|
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||||
|
1. Numerical accuracy (exact values, units, currency)
|
||||||
|
2. Key financial concepts and terminology
|
||||||
|
3. Overall factual correctness
|
||||||
|
|
||||||
|
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||||
|
|
||||||
|
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
judge_response = self.openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
messages=[{"role": "user", "content": judge_prompt}],
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||||
|
return judgment == "CORRECT"
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||||
|
# Fallback to simple string matching
|
||||||
|
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||||
|
return gt_clean in gen_clean
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 EVALUATION RESULTS")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Index comparison analysis
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||||
|
|
||||||
|
print("\n📊 Accuracy:")
|
||||||
|
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||||
|
print(
|
||||||
|
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Timing Breakdown:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||||
|
|
||||||
|
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||||
|
search_pct = (
|
||||||
|
timing_metrics.get("avg_search_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
gen_pct = (
|
||||||
|
timing_metrics.get("avg_generation_time", 0)
|
||||||
|
/ timing_metrics.get("avg_total_time", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print("\n📈 Time Distribution:")
|
||||||
|
print(f" Search: {search_pct:.1f}%")
|
||||||
|
print(f" Generation: {gen_pct:.1f}%")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load FinanceBench queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Test with more queries for robust measurement
|
||||||
|
test_queries = queries[:2000]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = FinanceBenchEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load queries for testing
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
queries = []
|
||||||
|
with open(args.dataset, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
|
||||||
|
# Use more queries for robust measurement
|
||||||
|
test_queries = queries[:200]
|
||||||
|
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity (without recompute for speed)
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 32
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
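# Illustrative trace of the binary search above (made-up recall numbers): with the
# range [1, 32], test complexity 16 -> recall 0.93 >= 0.90, so keep 16 and search [1, 15];
# test 8 -> 0.86 < 0.90, so search [9, 15]; test 12 -> 0.91, keep 12 and search [9, 11];
# continue until the range is empty, leaving the smallest complexity that met the target.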
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||||
|
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||||
|
|
||||||
|
# Use FinanceBench evaluator for QA evaluation
|
||||||
|
evaluator = FinanceBenchEvaluator(
|
||||||
|
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||||
|
)
|
||||||
|
|
||||||
|
print("📖 Loading FinanceBench dataset...")
|
||||||
|
data = evaluator.load_dataset(args.dataset)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes
|
||||||
|
print("\n📊 Index size comparison:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||||
|
)
|
||||||
|
print("\n📊 Index-only size comparison (.index file only):")
|
||||||
|
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||||
|
# Use index-only size for fair comparison (same as Enron emails)
|
||||||
|
storage_saving = (
|
||||||
|
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||||
|
/ non_compact_size_metrics["index_only_mb"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
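# Example with the figures reported in the README: a 172.2 MB non-compact .index file vs a
# 5.0 MB compact one gives (172.2 - 5.0) / 172.2 * 100 ≈ 97.1% storage saving.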
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for performance comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Generation evaluation
|
||||||
|
test_samples = 20
|
||||||
|
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||||
|
|
||||||
|
if args.llm_backend == "openai" and args.openai_api_key:
|
||||||
|
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||||
|
evaluation_start = time.time()
|
||||||
|
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||||
|
evaluation_time = time.time() - evaluation_start
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||||
|
)
|
||||||
|
evaluation_start = time.time()
try:
|
||||||
|
# Load LLM
|
||||||
|
if args.llm_backend == "hf":
|
||||||
|
tokenizer, model = load_hf_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_hf(tokenizer, model, prompt)
|
||||||
|
else: # vllm
|
||||||
|
llm, sampling_params = load_vllm_model(args.model_name)
|
||||||
|
|
||||||
|
def llm_func(prompt):
|
||||||
|
return generate_vllm(llm, sampling_params, prompt)
|
||||||
|
|
||||||
|
# Simple generation evaluation
|
||||||
|
queries = [item["question"] for item in data[:test_samples]]
|
||||||
|
gen_results = evaluate_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
llm_func,
|
||||||
|
queries,
|
||||||
|
domain="finance",
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": len(queries),
|
||||||
|
"avg_search_time": gen_results["avg_search_time"],
|
||||||
|
"avg_generation_time": gen_results["avg_generation_time"],
|
||||||
|
"results": gen_results["results"],
|
||||||
|
}
|
||||||
|
evaluation_time = time.time() - evaluation_start
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Generation evaluation failed: {e}")
|
||||||
|
timing_metrics = {
|
||||||
|
"total_questions": 0,
|
||||||
|
"avg_search_time": 0,
|
||||||
|
"avg_generation_time": 0,
|
||||||
|
}
|
||||||
|
evaluation_time = 0
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
**timing_metrics,
|
||||||
|
"total_evaluation_time": evaluation_time,
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print("\n📊 Generation Results:")
|
||||||
|
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||||
|
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||||
|
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||||
|
print(" - Run full evaluation on complete dataset if needed")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
benchmarks/financebench/setup_financebench.py (Executable file, 462 lines)
@@ -0,0 +1,462 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
FinanceBench Complete Setup Script
|
||||||
|
Downloads all PDFs and builds full LEANN datastore
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import pymupdf
|
||||||
|
import requests
|
||||||
|
from leann import LeannBuilder, LeannSearcher
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class FinanceBenchSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||||
|
self.data_dir = self.base_dir / data_dir
|
||||||
|
self.pdf_dir = self.data_dir / "pdfs"
|
||||||
|
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||||
|
self.index_dir = self.data_dir / "index"
|
||||||
|
self.download_lock = Lock()
|
||||||
|
|
||||||
|
def download_dataset(self):
|
||||||
|
"""Download the main FinanceBench dataset"""
|
||||||
|
print("📊 Downloading FinanceBench dataset...")
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
if self.dataset_file.exists():
|
||||||
|
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(self.dataset_file, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||||
|
|
||||||
|
def get_pdf_list(self):
|
||||||
|
"""Get list of all PDF files from GitHub"""
|
||||||
|
print("📋 Fetching PDF list from GitHub...")
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
pdf_files = response.json()
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files")
|
||||||
|
return pdf_files
|
||||||
|
|
||||||
|
def download_single_pdf(self, pdf_info, position):
|
||||||
|
"""Download a single PDF file"""
|
||||||
|
pdf_name = pdf_info["name"]
|
||||||
|
pdf_path = self.pdf_dir / pdf_name
|
||||||
|
|
||||||
|
# Skip if already downloaded
|
||||||
|
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||||
|
return f"✅ {pdf_name} (cached)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Download PDF
|
||||||
|
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Write to file
|
||||||
|
with self.download_lock:
|
||||||
|
with open(pdf_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ {pdf_name}: {e!s}"
|
||||||
|
|
||||||
|
def download_all_pdfs(self, max_workers: int = 5):
|
||||||
|
"""Download all PDF files with parallel processing"""
|
||||||
|
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pdf_files = self.get_pdf_list()
|
||||||
|
|
||||||
|
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
# Submit all download tasks
|
||||||
|
future_to_pdf = {
|
||||||
|
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||||
|
for i, pdf_info in enumerate(pdf_files)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process completed downloads with progress bar
|
||||||
|
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||||
|
for future in as_completed(future_to_pdf):
|
||||||
|
result = future.result()
|
||||||
|
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# Verify downloads
|
||||||
|
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||||
|
|
||||||
|
# Show any failures
|
||||||
|
missing_pdfs = []
|
||||||
|
for pdf_info in pdf_files:
|
||||||
|
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||||
|
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||||
|
missing_pdfs.append(pdf_info["name"])
|
||||||
|
|
||||||
|
if missing_pdfs:
|
||||||
|
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||||
|
for pdf in missing_pdfs[:5]: # Show first 5
|
||||||
|
print(f" - {pdf}")
|
||||||
|
if len(missing_pdfs) > 5:
|
||||||
|
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||||
|
|
||||||
|
def build_leann_index(
|
||||||
|
self,
|
||||||
|
backend: str = "hnsw",
|
||||||
|
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||||
|
):
|
||||||
|
"""Build LEANN index from all PDFs"""
|
||||||
|
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
# Check if we have PDFs
|
||||||
|
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||||
|
if not pdf_files:
|
||||||
|
raise RuntimeError("No PDF files found! Run download first.")
|
||||||
|
|
||||||
|
print(f"Found {len(pdf_files)} PDF files to process")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Initialize builder with standard compact configuration
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||||
|
is_compact=True, # Enable compact storage (pruned)
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process PDFs and extract text
|
||||||
|
total_chunks = 0
|
||||||
|
failed_pdfs = []
|
||||||
|
|
||||||
|
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||||
|
try:
|
||||||
|
chunks = self.extract_pdf_text(pdf_path)
|
||||||
|
for chunk in chunks:
|
||||||
|
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||||
|
total_chunks += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||||
|
failed_pdfs.append(pdf_path.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build index in index directory
|
||||||
|
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||||
|
print(f"🔨 Building index: {index_path}")
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
|
||||||
|
print("✅ Index built successfully!")
|
||||||
|
print(f" 📁 Index path: {index_path}")
|
||||||
|
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||||
|
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||||
|
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||||
|
|
||||||
|
if failed_pdfs:
|
||||||
|
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||||
|
|
||||||
|
return str(index_path)
|
||||||
|
|
||||||
|
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||||
|
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann.api import compute_embeddings
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Read metadata from the built index
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
import json
|
||||||
|
|
||||||
|
meta = json.loads(f.read())
|
||||||
|
|
||||||
|
embedding_model = meta["embedding_model"]
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not os.path.isabs(passage_file):
|
||||||
|
index_dir = os.path.dirname(index_path)
|
||||||
|
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||||
|
|
||||||
|
print(f"📊 Loading passages from {passage_file}...")
|
||||||
|
print(f"🤖 Using embedding model: {embedding_model}")
|
||||||
|
|
||||||
|
# Load all passages for baseline
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"📄 Loaded {len(passages)} passages")
|
||||||
|
|
||||||
|
# Compute embeddings using the same method as LEANN
|
||||||
|
print("🧮 Computing embeddings...")
|
||||||
|
embeddings = compute_embeddings(
|
||||||
|
passages,
|
||||||
|
embedding_model,
|
||||||
|
mode="sentence-transformers",
|
||||||
|
use_server=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
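# Note: this two-argument add(n, ptr) call is the SWIG-level FAISS signature exposed by the
# faiss bundled with leann_backend_hnsw; stock faiss would accept index.add(embeddings_f32)
# directly (as done with the standard API in verify_recall.py).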
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(passage_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||||
|
"""Extract and chunk text from a PDF file"""
|
||||||
|
chunks = []
|
||||||
|
doc = pymupdf.open(pdf_path)
|
||||||
|
|
||||||
|
for page_num in range(len(doc)):
|
||||||
|
page = doc[page_num]
|
||||||
|
text = page.get_text() # type: ignore
|
||||||
|
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = {
|
||||||
|
"source_file": pdf_path.name,
|
||||||
|
"page_number": page_num + 1,
|
||||||
|
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||||
|
"company": pdf_path.name.split("_")[0],
|
||||||
|
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use recursive character splitting like LangChain
|
||||||
|
if len(text.split()) > 500:
|
||||||
|
# Split by double newlines (paragraphs)
|
||||||
|
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||||
|
|
||||||
|
current_chunk = ""
|
||||||
|
for para in paragraphs:
|
||||||
|
# If adding this paragraph would make chunk too long, save current chunk
|
||||||
|
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
current_chunk = para
|
||||||
|
else:
|
||||||
|
current_chunk = (current_chunk + " " + para).strip()
|
||||||
|
|
||||||
|
# Add the last chunk
|
||||||
|
if current_chunk.strip():
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": current_chunk.strip(),
|
||||||
|
"metadata": {
|
||||||
|
**metadata,
|
||||||
|
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Page is short enough, use as single chunk
|
||||||
|
chunks.append(
|
||||||
|
{
|
||||||
|
"text": text.strip(),
|
||||||
|
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def extract_year_from_filename(self, filename: str) -> str:
|
||||||
|
"""Extract year from PDF filename"""
|
||||||
|
# Try to find 4-digit year in filename
|
||||||
|
|
||||||
|
match = re.search(r"(\d{4})", filename)
|
||||||
|
return match.group(1) if match else "unknown"
|
||||||
|
|
||||||
|
def verify_setup(self, index_path: str):
|
||||||
|
"""Verify the setup by testing a simple query"""
|
||||||
|
print("🧪 Verifying setup with test query...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Test query
|
||||||
|
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||||
|
results = searcher.search(test_query, top_k=3)
|
||||||
|
|
||||||
|
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
company = result.metadata.get("company", "Unknown")
|
||||||
|
year = result.metadata.get("doc_period", "Unknown")
|
||||||
|
page = result.metadata.get("page_number", "Unknown")
|
||||||
|
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||||
|
print(f" {result.text[:100]}...")
|
||||||
|
|
||||||
|
searcher.cleanup()
|
||||||
|
print("✅ Setup verification completed successfully!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Setup verification failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--embedding-model",
|
||||||
|
default="sentence-transformers/all-mpnet-base-v2",
|
||||||
|
help="Embedding model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
parser.add_argument(
|
||||||
|
"--build-baseline-only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only build FAISS baseline from existing index",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🏦 FinanceBench Complete Setup")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
setup = FinanceBenchSetup(args.data_dir)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if args.build_baseline_only:
|
||||||
|
# Only build baseline from existing index
|
||||||
|
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||||
|
index_file = f"{index_path}.index"
|
||||||
|
meta_file = f"{index_path}.leann.meta.json"
|
||||||
|
|
||||||
|
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||||
|
print("❌ Index files not found:")
|
||||||
|
print(f" Index: {index_file}")
|
||||||
|
print(f" Meta: {meta_file}")
|
||||||
|
print("💡 Run without --build-baseline-only to build the index first")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Step 1: Download dataset
|
||||||
|
setup.download_dataset()
|
||||||
|
|
||||||
|
# Step 2: Download PDFs
|
||||||
|
if not args.skip_download:
|
||||||
|
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping PDF download")
|
||||||
|
|
||||||
|
# Step 3: Build LEANN index
|
||||||
|
if not args.skip_build:
|
||||||
|
index_path = setup.build_leann_index(
|
||||||
|
backend=args.backend, embedding_model=args.embedding_model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Building FAISS flat baseline...")
|
||||||
|
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||||
|
print(f"✅ Baseline built at {baseline_path}")
|
||||||
|
|
||||||
|
# Step 5: Verify setup
|
||||||
|
setup.verify_setup(index_path)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
|
||||||
|
print("\n🎉 FinanceBench setup completed!")
|
||||||
|
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||||
|
print("\nNext steps:")
|
||||||
|
print(
|
||||||
|
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||||
|
)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
benchmarks/financebench/verify_recall.py (Normal file, 214 lines)
@@ -0,0 +1,214 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.9"
|
||||||
|
# dependencies = [
|
||||||
|
# "faiss-cpu",
|
||||||
|
# "numpy",
|
||||||
|
# "sentence-transformers",
|
||||||
|
# "torch",
|
||||||
|
# "tqdm",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
"""
|
||||||
|
Independent recall verification script using standard FAISS.
|
||||||
|
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Direct embedding computation using sentence-transformers.
|
||||||
|
Copied logic to avoid dependency issues.
|
||||||
|
"""
|
||||||
|
print(f"Loading model: {model_name}")
|
||||||
|
model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
|
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||||
|
embeddings = model.encode(
|
||||||
|
chunks,
|
||||||
|
show_progress_bar=True,
|
||||||
|
batch_size=32,
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return embeddings.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||||
|
"""Load FinanceBench queries from dataset"""
|
||||||
|
queries = []
|
||||||
|
with open(dataset_path, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
queries.append(data["question"])
|
||||||
|
if len(queries) >= max_queries:
|
||||||
|
break
|
||||||
|
return queries
|
||||||
|
|
||||||
|
|
||||||
|
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||||
|
"""Load passages from LEANN index structure"""
|
||||||
|
meta_path = f"{index_path}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
index_dir = Path(index_path).parent
|
||||||
|
passage_file = index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"Loading passages from {passage_file}")
|
||||||
|
|
||||||
|
passages = []
|
||||||
|
passage_ids = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Loading passages"):
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
passages.append(data["text"])
|
||||||
|
passage_ids.append(data["id"])
|
||||||
|
|
||||||
|
print(f"Loaded {len(passages)} passages")
|
||||||
|
return passages, passage_ids
|
||||||
|
|
||||||
|
|
||||||
|
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||||
|
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
|
||||||
|
# Build Flat index (ground truth)
|
||||||
|
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||||
|
flat_index = faiss.IndexFlatIP(dimension)
|
||||||
|
flat_index.add(embeddings)
|
||||||
|
|
||||||
|
# Build HNSW index
|
||||||
|
print("Building FAISS IndexHNSWFlat...")
|
||||||
|
M = 32 # Same as LEANN default
|
||||||
|
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||||
|
hnsw_index.add(embeddings)
|
||||||
|
|
||||||
|
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||||
|
return flat_index, hnsw_index
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_recall_at_k(
|
||||||
|
query_embeddings: np.ndarray,
|
||||||
|
flat_index: faiss.Index,
|
||||||
|
hnsw_index: faiss.Index,
|
||||||
|
passage_ids: list[str],
|
||||||
|
k: int = 3,
|
||||||
|
ef_search: int = 64,
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||||
|
|
||||||
|
# Set search parameters for HNSW
|
||||||
|
hnsw_index.hnsw.efSearch = ef_search
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = query_embeddings.shape[0]
|
||||||
|
|
||||||
|
for i in range(num_queries):
|
||||||
|
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||||
|
|
||||||
|
# Get ground truth from Flat index (standard FAISS API)
|
||||||
|
flat_distances, flat_indices = flat_index.search(query, k)
|
||||||
|
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||||
|
|
||||||
|
# Get results from HNSW index (standard FAISS API)
|
||||||
|
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||||
|
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||||
|
|
||||||
|
# Calculate recall
|
||||||
|
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||||
|
recall = len(intersection) / k
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||||
|
print(f" Flat: {list(ground_truth_ids)}")
|
||||||
|
print(f" HNSW: {list(hnsw_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Configuration
|
||||||
|
dataset_path = "data/financebench_merged.jsonl"
|
||||||
|
index_path = "data/index/financebench_full_hnsw.leann"
|
||||||
|
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||||
|
|
||||||
|
print("🔍 FAISS Recall Verification")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# Check if files exist
|
||||||
|
if not Path(dataset_path).exists():
|
||||||
|
print(f"❌ Dataset not found: {dataset_path}")
|
||||||
|
return
|
||||||
|
if not Path(f"{index_path}.meta.json").exists():
|
||||||
|
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
print("📖 Loading FinanceBench queries...")
|
||||||
|
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||||
|
print(f"Loaded {len(queries)} queries")
|
||||||
|
|
||||||
|
print("📄 Loading passages from LEANN index...")
|
||||||
|
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||||
|
|
||||||
|
# Compute embeddings
|
||||||
|
print("🧮 Computing passage embeddings...")
|
||||||
|
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||||
|
|
||||||
|
print("🧮 Computing query embeddings...")
|
||||||
|
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||||
|
|
||||||
|
# Build FAISS indexes
|
||||||
|
print("🏗️ Building FAISS indexes...")
|
||||||
|
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||||
|
|
||||||
|
# Test different efSearch values (equivalent to LEANN complexity)
|
||||||
|
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||||
|
ef_search_values = [16, 32, 64, 128, 256]
|
||||||
|
|
||||||
|
for ef_search in ef_search_values:
|
||||||
|
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
recall = evaluate_recall_at_k(
|
||||||
|
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n✅ Verification completed!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||||
|
print(" - Compared recall@3 at different efSearch values")
|
||||||
|
print(" - Used same embedding model as LEANN")
|
||||||
|
print(" - This validates LEANN's recall measurements")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
benchmarks/laion/.gitignore (vendored, Normal file, 1 line)
@@ -0,0 +1 @@
|
|||||||
|
data/
|
||||||
benchmarks/laion/README.md (Normal file, 199 lines)
@@ -0,0 +1,199 @@
|
|||||||
|
# LAION Multimodal Benchmark
|
||||||
|
|
||||||
|
A multimodal benchmark that evaluates image retrieval and generation performance on a LAION dataset subset, using LEANN with CLIP embeddings for retrieval and Qwen2.5-VL for multimodal generation.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This benchmark evaluates:
|
||||||
|
- **Image retrieval timing** using caption-based queries
|
||||||
|
- **Recall@K performance** for image search
|
||||||
|
- **Complexity analysis** across different search parameters
|
||||||
|
- **Index size and storage efficiency**
|
||||||
|
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||||
|
|
||||||
|
## Dataset Configuration
|
||||||
|
|
||||||
|
- **Dataset**: LAION-400M subset (10,000 images)
|
||||||
|
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||||
|
- **Queries**: 200 random captions from the dataset
|
||||||
|
- **Ground Truth**: Self-recall (query caption → original image)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Setup the benchmark
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd benchmarks/laion
|
||||||
|
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||||
|
```
|
||||||
|
|
||||||
|
This will:
|
||||||
|
- Create dummy LAION data (10K samples)
|
||||||
|
- Generate CLIP embeddings (512-dim)
|
||||||
|
- Build LEANN index with HNSW backend
|
||||||
|
- Create 200 evaluation queries
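Each line of `data/evaluation_queries.jsonl` is a JSON object whose `query` field holds the caption text (this is the field `evaluate_laion.py` reads; no other fields are assumed here). A minimal sketch for peeking at the generated queries:

```python
import json

# Print the first few caption queries from the generated file.
with open("data/evaluation_queries.jsonl", encoding="utf-8") as f:
    for line, _ in zip(f, range(5)):
        if line.strip():
            print(json.loads(line)["query"])
```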
|
||||||
|
|
||||||
|
### 2. Run evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all evaluation stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann
|
||||||
|
|
||||||
|
# Run specific stages
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||||
|
|
||||||
|
# Multimodal generation with Qwen2.5-VL
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Save results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Setup Options
|
||||||
|
```bash
|
||||||
|
python setup_laion.py \
|
||||||
|
--num-samples 10000 \
|
||||||
|
--num-queries 200 \
|
||||||
|
--index-path data/laion_index.leann \
|
||||||
|
--backend hnsw
|
||||||
|
```
|
||||||
|
|
||||||
|
### Evaluation Options
|
||||||
|
```bash
|
||||||
|
python evaluate_laion.py \
|
||||||
|
--index data/laion_index.leann \
|
||||||
|
--queries data/evaluation_queries.jsonl \
|
||||||
|
--complexity 64 \
|
||||||
|
--top-k 3 \
|
||||||
|
--num-samples 100 \
|
||||||
|
--stage all
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation Stages
|
||||||
|
|
||||||
|
### Stage 2: Recall Evaluation
|
||||||
|
- Evaluates Recall@3 for multimodal retrieval
|
||||||
|
- Compares LEANN vs FAISS baseline performance
|
||||||
|
- Self-recall: query caption should retrieve original image
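Concretely, recall@3 is the fraction of the FAISS flat ground-truth ids that LEANN's top-3 results recover. A minimal sketch (illustrative names; the real logic lives in `RecallEvaluator.evaluate_recall_at_3` in `evaluate_laion.py`):

```python
def recall_at_k(leann_ids: set[str], ground_truth_ids: set[str], k: int = 3) -> float:
    # Fraction of the flat-index ground truth that the LEANN search recovered.
    return len(leann_ids & ground_truth_ids) / k

# Two of the three ground-truth images retrieved -> 2/3
print(recall_at_k({"img_1", "img_2", "img_9"}, {"img_1", "img_2", "img_3"}))
```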
|
||||||
|
|
||||||
|
### Stage 3: Complexity Analysis
|
||||||
|
- Binary search for optimal complexity (90% recall target)
|
||||||
|
- Tests performance across different complexity levels
|
||||||
|
- Analyzes speed vs. accuracy tradeoffs
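The search itself is a plain binary search over the complexity parameter, keeping the smallest value whose measured Recall@3 reaches the target. A condensed sketch of the loop in `evaluate_laion.py` (`measure_recall` is a stand-in for the actual recall evaluation call):

```python
def find_optimal_complexity(measure_recall, target: float = 0.9, lo: int = 1, hi: int = 128):
    best = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if measure_recall(mid) >= target:
            best = mid        # target reached: try a cheaper (lower) complexity
            hi = mid - 1
        else:
            lo = mid + 1      # below target: need a higher complexity
    return best
```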
|
||||||
|
|
||||||
|
### Stage 4: Index Comparison
|
||||||
|
- Compares compact vs non-compact index sizes
|
||||||
|
- Measures search performance differences
|
||||||
|
- Reports storage efficiency and speed ratios
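Both reported numbers are simple ratios over the measured sizes and timings, matching the formulas used in the evaluation scripts (the sample values below are purely illustrative):

```python
def storage_saving_percent(compact_mb: float, non_compact_mb: float) -> float:
    # How much smaller the compact .index file is, relative to the non-compact one.
    return (non_compact_mb - compact_mb) / non_compact_mb * 100

def speed_ratio(non_compact_avg_s: float, compact_avg_s: float) -> float:
    # A value below 1.0 means the compact index (with recompute) is slower per search.
    return non_compact_avg_s / compact_avg_s

print(storage_saving_percent(compact_mb=120.0, non_compact_mb=480.0))  # -> 75.0
```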
|
||||||
|
|
||||||
|
### Stage 5: Multimodal Generation
|
||||||
|
- Uses Qwen2.5-VL for image understanding and description
|
||||||
|
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||||
|
- Measures both search and generation timing
|
||||||
|
|
||||||
|
## Output Metrics
|
||||||
|
|
||||||
|
### Timing Metrics
|
||||||
|
- Average/median/min/max search time
|
||||||
|
- Standard deviation
|
||||||
|
- Searches per second
|
||||||
|
- Latency in milliseconds
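A minimal sketch of how such a summary can be derived from raw per-query search times (the evaluation script reports equivalent statistics; exact field names may differ):

```python
import statistics

def timing_summary(search_times_s: list[float]) -> dict:
    # Derive the reported timing metrics from per-query search times (seconds).
    avg = statistics.mean(search_times_s)
    return {
        "avg_s": avg,
        "median_s": statistics.median(search_times_s),
        "min_s": min(search_times_s),
        "max_s": max(search_times_s),
        "std_s": statistics.stdev(search_times_s) if len(search_times_s) > 1 else 0.0,
        "searches_per_second": 1.0 / avg if avg > 0 else float("inf"),
        "latency_ms": avg * 1000,
    }
```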
|
||||||
|
|
||||||
|
### Recall Metrics
|
||||||
|
- Recall@3 percentage for image retrieval
|
||||||
|
- Number of queries with ground truth
|
||||||
|
|
||||||
|
### Index Metrics
|
||||||
|
- Total index size (MB)
|
||||||
|
- Component breakdown (index, passages, metadata)
|
||||||
|
- Storage savings (compact vs non-compact)
|
||||||
|
- Backend and embedding model info
|
||||||
|
|
||||||
|
### Generation Metrics (Stage 5)
|
||||||
|
- Average search time per query
|
||||||
|
- Average generation time per query
|
||||||
|
- Time distribution (search vs generation)
|
||||||
|
- Sample multimodal responses
|
||||||
|
- Model: Qwen2.5-VL performance
|
||||||
|
|
||||||
|
## Benchmark Results
|
||||||
|
|
||||||
|
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||||
|
|
||||||
|
**Stage 3: Optimal Complexity Analysis**
|
||||||
|
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||||
|
- **Binary Search Range**: 1-128
|
||||||
|
- **Target Recall**: 90%
|
||||||
|
- **Index Type**: Non-compact (for fast binary search)
|
||||||
|
|
||||||
|
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||||
|
- **Total Queries**: 20
|
||||||
|
- **Average Search Time**: 1.200s per query
|
||||||
|
- **Average Generation Time**: 6.558s per query
|
||||||
|
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||||
|
- **LLM Backend**: HuggingFace transformers
|
||||||
|
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
- **Optimal Complexity**: 85
|
||||||
|
|
||||||
|
**System Performance:**
|
||||||
|
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||||
|
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||||
|
- **Backend**: HNSW with cosine distance
|
||||||
|
|
||||||
|
### Example Results
|
||||||
|
|
||||||
|
```
|
||||||
|
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
📊 Multimodal Generation Results:
|
||||||
|
Total Queries: 20
|
||||||
|
Avg Search Time: 1.200s
|
||||||
|
Avg Generation Time: 6.558s
|
||||||
|
Time Distribution: Search 15.5%, Generation 84.5%
|
||||||
|
LLM Backend: HuggingFace transformers
|
||||||
|
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
⚙️ Optimal Complexity Analysis:
|
||||||
|
Target Recall: 90%
|
||||||
|
Optimal Complexity: 85
|
||||||
|
Binary Search Range: 1-128
|
||||||
|
Non-compact Index (fast search, no recompute)
|
||||||
|
|
||||||
|
🚀 Performance Summary:
|
||||||
|
Multimodal RAG: 7.758s total per query
|
||||||
|
Search: 15.5% of total time
|
||||||
|
Generation: 84.5% of total time
|
||||||
|
```
|
||||||
|
|
||||||
|
## Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
benchmarks/laion/
|
||||||
|
├── setup_laion.py # Setup script
|
||||||
|
├── evaluate_laion.py # Evaluation script
|
||||||
|
├── README.md # This file
|
||||||
|
└── data/ # Generated data
|
||||||
|
├── laion_images/ # Image files (placeholder)
|
||||||
|
├── laion_metadata.jsonl # Image metadata
|
||||||
|
├── laion_passages.jsonl # LEANN passages
|
||||||
|
├── laion_embeddings.npy # CLIP embeddings
|
||||||
|
├── evaluation_queries.jsonl # Evaluation queries
|
||||||
|
└── laion_index.leann/ # LEANN index files
|
||||||
|
```
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Current implementation uses dummy data for demonstration
|
||||||
|
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||||
|
- CLIP embeddings are randomly generated; replace them with a real CLIP model for production use
|
||||||
|
- Adjust `num_samples` and `num_queries` based on available resources
|
||||||
|
- Consider using `--num-samples` during evaluation for faster testing
|
||||||
benchmarks/laion/evaluate_laion.py (Normal file, 725 lines)
@@ -0,0 +1,725 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from leann import LeannSearcher
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||||
|
|
||||||
|
# Setup logging to reduce verbose output
|
||||||
|
logging.basicConfig(level=logging.WARNING)
|
||||||
|
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RecallEvaluator:
|
||||||
|
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||||
|
|
||||||
|
def __init__(self, index_path: str, baseline_dir: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.baseline_dir = baseline_dir
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
# Load FAISS flat baseline (image embeddings)
|
||||||
|
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||||
|
with open(metadata_path, "rb") as f:
|
||||||
|
self.image_ids = pickle.load(f)
|
||||||
|
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||||
|
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
def evaluate_recall_at_3(
|
||||||
|
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||||
|
) -> float:
|
||||||
|
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||||
|
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||||
|
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||||
|
|
||||||
|
total_recall = 0.0
|
||||||
|
num_queries = len(captions)
|
||||||
|
|
||||||
|
for i, caption in enumerate(captions):
|
||||||
|
# Get ground truth: search with FAISS flat using caption text embedding
|
||||||
|
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||||
|
query_embedding = self.st_clip.encode(
|
||||||
|
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||||
|
n = query_embedding.shape[0] # Number of queries
|
||||||
|
k = 3 # Number of nearest neighbors
|
||||||
|
distances = np.zeros((n, k), dtype=np.float32)
|
||||||
|
labels = np.zeros((n, k), dtype=np.int64)
|
||||||
|
|
||||||
|
self.faiss_index.search(
|
||||||
|
n,
|
||||||
|
faiss.swig_ptr(query_embedding),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(labels),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the results (image IDs from FAISS)
|
||||||
|
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||||
|
|
||||||
|
# Search with LEANN at specified complexity (using caption as text query)
|
||||||
|
test_results = self.searcher.search(
|
||||||
|
caption,
|
||||||
|
top_k=3,
|
||||||
|
complexity=complexity,
|
||||||
|
recompute_embeddings=recompute_embeddings,
|
||||||
|
)
|
||||||
|
test_ids = {result.id for result in test_results}
|
||||||
|
|
||||||
|
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||||
|
intersection = test_ids.intersection(baseline_ids)
|
||||||
|
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||||
|
total_recall += recall
|
||||||
|
|
||||||
|
if i < 3: # Show first few examples
|
||||||
|
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||||
|
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||||
|
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||||
|
print(f" Intersection: {list(intersection)}")
|
||||||
|
|
||||||
|
avg_recall = total_recall / num_queries
|
||||||
|
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||||
|
return avg_recall
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if hasattr(self, "searcher"):
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONEvaluator:
|
||||||
|
def __init__(self, index_path: str):
|
||||||
|
self.index_path = index_path
|
||||||
|
self.searcher = LeannSearcher(index_path)
|
||||||
|
|
||||||
|
def load_queries(self, queries_file: str) -> list[str]:
|
||||||
|
"""Load caption queries from evaluation file"""
|
||||||
|
captions = []
|
||||||
|
with open(queries_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
query_data = json.loads(line)
|
||||||
|
captions.append(query_data["query"])
|
||||||
|
|
||||||
|
print(f"📊 Loaded {len(captions)} caption queries")
|
||||||
|
return captions
|
||||||
|
|
||||||
|
def analyze_index_sizes(self) -> dict:
|
||||||
|
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||||
|
print("📏 Analyzing index sizes (.index only)...")
|
||||||
|
|
||||||
|
# Get all index-related files
|
||||||
|
index_path = Path(self.index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.stem # Remove .leann extension
|
||||||
|
|
||||||
|
sizes: dict[str, float] = {}
|
||||||
|
|
||||||
|
# Core index files
|
||||||
|
index_file = index_dir / f"{index_name}.index"
|
||||||
|
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||||
|
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||||
|
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||||
|
|
||||||
|
# Core index size (.index only)
|
||||||
|
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||||
|
sizes["index_only_mb"] = index_mb
|
||||||
|
|
||||||
|
# Other files for reference (not counted in index_only_mb)
|
||||||
|
sizes["metadata_mb"] = (
|
||||||
|
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_text_mb"] = (
|
||||||
|
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
sizes["passages_index_mb"] = (
|
||||||
|
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||||
|
if sizes["metadata_mb"]:
|
||||||
|
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||||
|
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||||
|
print(
|
||||||
|
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||||
|
"""Create a non-compact index for comparison purposes"""
|
||||||
|
print("🏗️ Building non-compact index from existing passages...")
|
||||||
|
|
||||||
|
# Load existing passages from current index
|
||||||
|
from leann import LeannBuilder
|
||||||
|
|
||||||
|
current_index_path = Path(self.index_path)
|
||||||
|
current_index_dir = current_index_path.parent
|
||||||
|
current_index_name = current_index_path.name
|
||||||
|
|
||||||
|
# Read metadata to get passage source
|
||||||
|
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||||
|
with open(meta_path) as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
passage_source = meta["passage_sources"][0]
|
||||||
|
passage_file = passage_source["path"]
|
||||||
|
|
||||||
|
# Convert relative path to absolute
|
||||||
|
if not Path(passage_file).is_absolute():
|
||||||
|
passage_file = current_index_dir / Path(passage_file).name
|
||||||
|
|
||||||
|
print(f"📄 Loading passages from {passage_file}...")
|
||||||
|
|
||||||
|
# Load CLIP embeddings
|
||||||
|
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
# Build non-compact index with same passages and embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
is_recompute=False, # Disable recompute (store embeddings)
|
||||||
|
is_compact=False, # Disable compact storage
|
||||||
|
distance_metric="cosine",
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in meta.get("backend_kwargs", {}).items()
|
||||||
|
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare ids and add passages
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passage_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
data = json.loads(line)
|
||||||
|
ids.append(str(data["id"]))
|
||||||
|
# Ensure metadata contains the id used by the vector index
|
||||||
|
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||||
|
builder.add_text(text=data["text"], metadata=metadata)
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist a pickle for build_index_from_embeddings
|
||||||
|
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||||
|
)
|
||||||
|
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||||
|
|
||||||
|
# Analyze the non-compact index size
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_sizes["index_type"] = "non_compact"
|
||||||
|
|
||||||
|
return non_compact_sizes
|
||||||
|
|
||||||
|
def compare_index_performance(
|
||||||
|
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||||
|
) -> dict:
|
||||||
|
"""Compare performance between non-compact and compact indexes"""
|
||||||
|
print("⚡ Comparing search performance between indexes...")
|
||||||
|
|
||||||
|
# Test queries
|
||||||
|
test_queries = test_captions[:5]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"non_compact": {"search_times": []},
|
||||||
|
"compact": {"search_times": []},
|
||||||
|
"avg_search_times": {},
|
||||||
|
"speed_ratio": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test non-compact index (no recompute)
|
||||||
|
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||||
|
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = non_compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["non_compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Test compact index (with recompute)
|
||||||
|
print(" 🔍 Testing compact index (with recompute)...")
|
||||||
|
compact_searcher = LeannSearcher(compact_path)
|
||||||
|
|
||||||
|
for caption in test_queries:
|
||||||
|
start_time = time.time()
|
||||||
|
_ = compact_searcher.search(
|
||||||
|
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
results["compact"]["search_times"].append(search_time)
|
||||||
|
|
||||||
|
# Calculate averages
|
||||||
|
results["avg_search_times"]["non_compact"] = sum(
|
||||||
|
results["non_compact"]["search_times"]
|
||||||
|
) / len(results["non_compact"]["search_times"])
|
||||||
|
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||||
|
results["compact"]["search_times"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance ratio
|
||||||
|
if results["avg_search_times"]["compact"] > 0:
|
||||||
|
results["speed_ratio"] = (
|
||||||
|
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results["speed_ratio"] = float("inf")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||||
|
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
non_compact_searcher.cleanup()
|
||||||
|
compact_searcher.cleanup()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _print_results(self, timing_metrics: dict):
|
||||||
|
"""Print evaluation results"""
|
||||||
|
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Index comparison analysis (prefer .index-only view if present)
|
||||||
|
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||||
|
current = timing_metrics["current_index"]
|
||||||
|
non_compact = timing_metrics["non_compact_index"]
|
||||||
|
|
||||||
|
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||||
|
print("\n📏 Index Comparison Analysis (.index only):")
|
||||||
|
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
# Show excluded components for reference if available
|
||||||
|
if any(
|
||||||
|
k in non_compact
|
||||||
|
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||||
|
):
|
||||||
|
print(" (passages excluded in totals, shown for reference):")
|
||||||
|
print(
|
||||||
|
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||||
|
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||||
|
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to legacy totals if running with older metrics
|
||||||
|
print("\n📏 Index Comparison Analysis:")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||||
|
)
|
||||||
|
print(" Component breakdown (non-compact):")
|
||||||
|
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||||
|
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||||
|
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||||
|
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||||
|
|
||||||
|
# Performance comparison
|
||||||
|
if "performance_comparison" in timing_metrics:
|
||||||
|
perf = timing_metrics["performance_comparison"]
|
||||||
|
print("\n⚡ Performance Comparison:")
|
||||||
|
print(
|
||||||
|
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||||
|
)
|
||||||
|
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
# Legacy single index analysis (fallback)
|
||||||
|
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||||
|
print("\n📏 Index Size Analysis:")
|
||||||
|
print(
|
||||||
|
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""Cleanup resources"""
|
||||||
|
if self.searcher:
|
||||||
|
self.searcher.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||||
|
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||||
|
parser.add_argument(
|
||||||
|
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stage",
|
||||||
|
choices=["2", "3", "4", "5", "all"],
|
||||||
|
default="all",
|
||||||
|
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||||
|
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||||
|
parser.add_argument("--output", help="Save results to JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-backend",
|
||||||
|
choices=["hf"],
|
||||||
|
default="hf",
|
||||||
|
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if baseline exists
|
||||||
|
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||||
|
if not os.path.exists(baseline_index_path):
|
||||||
|
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||||
|
print("💡 Please run setup_laion.py first to build the baseline")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if args.stage == "2" or args.stage == "all":
|
||||||
|
# Stage 2: Recall@3 evaluation
|
||||||
|
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||||
|
|
||||||
|
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
laion_evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = laion_evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Test with queries for robust measurement
|
||||||
|
test_captions = captions[:100] # Use subset for speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Test with complexity 64
|
||||||
|
complexity = 64
|
||||||
|
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||||
|
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 2 completed!\n")
|
||||||
|
|
||||||
|
# Shared non-compact index path for Stage 3 and 4
|
||||||
|
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||||
|
complexity = args.complexity
|
||||||
|
|
||||||
|
if args.stage == "3" or args.stage == "all":
|
||||||
|
# Stage 3: Binary search for 90% recall complexity
|
||||||
|
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||||
|
print(
|
||||||
|
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create non-compact index for binary search
|
||||||
|
print("🏗️ Creating non-compact index for binary search...")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||||
|
|
||||||
|
# Use non-compact index for binary search
|
||||||
|
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||||
|
|
||||||
|
# Load caption queries for testing
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Use subset for robust measurement
|
||||||
|
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||||
|
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||||
|
|
||||||
|
# Binary search for 90% recall complexity
|
||||||
|
target_recall = 0.9
|
||||||
|
min_complexity, max_complexity = 1, 128
|
||||||
|
|
||||||
|
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||||
|
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||||
|
|
||||||
|
best_complexity = None
|
||||||
|
best_recall = 0.0
|
||||||
|
|
||||||
|
while min_complexity <= max_complexity:
|
||||||
|
mid_complexity = (min_complexity + max_complexity) // 2
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||||
|
)
|
||||||
|
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, mid_complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recall >= target_recall:
|
||||||
|
best_complexity = mid_complexity
|
||||||
|
best_recall = recall
|
||||||
|
max_complexity = mid_complexity - 1
|
||||||
|
print(" ✅ Target reached! Searching for lower complexity...")
|
||||||
|
else:
|
||||||
|
min_complexity = mid_complexity + 1
|
||||||
|
print(" ❌ Below target. Searching for higher complexity...")
|
||||||
|
|
||||||
|
if best_complexity is not None:
|
||||||
|
print("\n🎯 Optimal complexity found!")
|
||||||
|
print(f" Complexity: {best_complexity}")
|
||||||
|
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||||
|
|
||||||
|
# Test a few complexities around the optimal one for verification
|
||||||
|
print("\n🔬 Verification test around optimal complexity:")
|
||||||
|
verification_complexities = [
|
||||||
|
max(1, best_complexity - 2),
|
||||||
|
max(1, best_complexity - 1),
|
||||||
|
best_complexity,
|
||||||
|
best_complexity + 1,
|
||||||
|
best_complexity + 2,
|
||||||
|
]
|
||||||
|
|
||||||
|
for complexity in verification_complexities:
|
||||||
|
if complexity <= 512: # reasonable upper bound
|
||||||
|
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions, complexity, recompute_embeddings=False
|
||||||
|
)
|
||||||
|
status = "✅" if recall >= target_recall else "❌"
|
||||||
|
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||||
|
|
||||||
|
# Now test the optimal complexity with compact index and recompute for comparison
|
||||||
|
print(
|
||||||
|
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||||
|
)
|
||||||
|
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||||
|
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||||
|
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||||
|
)
|
||||||
|
complexity = best_complexity
|
||||||
|
print(
|
||||||
|
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||||
|
)
|
||||||
|
compact_evaluator.cleanup()
|
||||||
|
else:
|
||||||
|
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||||
|
print("All tested complexities were below target.")
|
||||||
|
|
||||||
|
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||||
|
binary_search_evaluator.cleanup()
|
||||||
|
evaluator.cleanup()
|
||||||
|
|
||||||
|
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||||
|
|
||||||
|
if args.stage == "4" or args.stage == "all":
|
||||||
|
# Stage 4: Index comparison (without LLM generation)
|
||||||
|
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||||
|
|
||||||
|
# Use LAION evaluator for index comparison
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
|
||||||
|
# Load caption queries
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
|
||||||
|
# Step 1: Analyze current (compact) index
|
||||||
|
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||||
|
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||||
|
compact_size_metrics["index_type"] = "compact"
|
||||||
|
|
||||||
|
# Step 2: Use existing non-compact index or create if needed
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
print(
|
||||||
|
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||||
|
)
|
||||||
|
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||||
|
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||||
|
non_compact_size_metrics["index_type"] = "non_compact"
|
||||||
|
else:
|
||||||
|
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||||
|
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||||
|
non_compact_index_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Compare index sizes (.index only)
|
||||||
|
print("\n📊 Index size comparison (.index only):")
|
||||||
|
print(
|
||||||
|
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||||
|
)
|
||||||
|
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||||
|
|
||||||
|
storage_saving = 0.0
|
||||||
|
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||||
|
storage_saving = (
|
||||||
|
(
|
||||||
|
non_compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
- compact_size_metrics.get("index_only_mb", 0)
|
||||||
|
)
|
||||||
|
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||||
|
|
||||||
|
# Step 4: Performance comparison between the two indexes
|
||||||
|
if complexity is None:
|
||||||
|
raise ValueError("Complexity is required for index comparison")
|
||||||
|
|
||||||
|
print("\n⚡ Performance comparison between indexes...")
|
||||||
|
performance_metrics = evaluator.compare_index_performance(
|
||||||
|
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine all metrics
|
||||||
|
combined_metrics = {
|
||||||
|
"current_index": compact_size_metrics,
|
||||||
|
"non_compact_index": non_compact_size_metrics,
|
||||||
|
"performance_comparison": performance_metrics,
|
||||||
|
"storage_saving_percent": storage_saving,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Print comprehensive results
|
||||||
|
evaluator._print_results(combined_metrics)
|
||||||
|
|
||||||
|
# Save results if requested
|
||||||
|
if args.output:
|
||||||
|
print(f"\n💾 Saving results to {args.output}...")
|
||||||
|
with open(args.output, "w") as f:
|
||||||
|
json.dump(combined_metrics, f, indent=2, default=str)
|
||||||
|
print(f"✅ Results saved to {args.output}")
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 4 completed!\n")
|
||||||
|
|
||||||
|
if args.stage in ("5", "all"):
|
||||||
|
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||||
|
evaluator = LAIONEvaluator(args.index)
|
||||||
|
captions = evaluator.load_queries(args.queries)
|
||||||
|
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||||
|
|
||||||
|
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||||
|
|
||||||
|
# Load Qwen2.5-VL model
|
||||||
|
try:
|
||||||
|
print("Loading Qwen2.5-VL model...")
|
||||||
|
processor, model = load_qwen_vl_model(args.model_name)
|
||||||
|
|
||||||
|
# Run multimodal generation evaluation
|
||||||
|
complexity = args.complexity or 64
|
||||||
|
gen_results = evaluate_multimodal_rag(
|
||||||
|
evaluator.searcher,
|
||||||
|
test_captions,
|
||||||
|
processor=processor,
|
||||||
|
model=model,
|
||||||
|
complexity=complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n📊 Multimodal Generation Results:")
|
||||||
|
print(f" Total Queries: {len(test_captions)}")
|
||||||
|
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||||
|
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||||
|
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||||
|
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||||
|
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||||
|
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||||
|
print(" LLM Backend: HuggingFace transformers")
|
||||||
|
print(f" Model: {args.model_name}")
|
||||||
|
|
||||||
|
# Show sample results
|
||||||
|
print("\n📝 Sample Multimodal Generations:")
|
||||||
|
for i, response in enumerate(gen_results["results"][:3]):
|
||||||
|
# Handle both string and dict formats for captions
|
||||||
|
if isinstance(test_captions[i], dict):
|
||||||
|
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||||
|
else:
|
||||||
|
caption_text = str(test_captions[i])
|
||||||
|
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||||
|
print(f" Response {i + 1}: {response[:100]}...")
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||||
|
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
evaluator.cleanup()
|
||||||
|
print("✅ Stage 5 completed!\n")
|
||||||
|
|
||||||
|
if args.stage == "all":
|
||||||
|
print("🎉 All evaluation stages completed successfully!")
|
||||||
|
print("\n📋 Summary:")
|
||||||
|
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||||
|
print(" Stage 3: ✅ Optimal complexity found")
|
||||||
|
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||||
|
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||||
|
print("\n🔧 Recommended next steps:")
|
||||||
|
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||||
|
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||||
|
|
||||||
|
# Clean up non-compact index after all stages complete
|
||||||
|
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||||
|
if Path(non_compact_index_path).exists():
|
||||||
|
temp_index_dir = Path(non_compact_index_path).parent
|
||||||
|
temp_index_name = Path(non_compact_index_path).name
|
||||||
|
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||||
|
temp_file.unlink()
|
||||||
|
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print("📝 No temporary index to clean up")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Evaluation interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
benchmarks/laion/setup_laion.py (new file, 576 lines)
@@ -0,0 +1,576 @@
|
|||||||
|
"""
|
||||||
|
LAION Multimodal Benchmark Setup Script
|
||||||
|
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from datasets import load_dataset
|
||||||
|
from leann import LeannBuilder
|
||||||
|
from PIL import Image
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
class LAIONSetup:
|
||||||
|
def __init__(self, data_dir: str = "data"):
|
||||||
|
self.data_dir = Path(data_dir)
|
||||||
|
self.images_dir = self.data_dir / "laion_images"
|
||||||
|
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||||
|
|
||||||
|
# Create directories
|
||||||
|
self.data_dir.mkdir(exist_ok=True)
|
||||||
|
self.images_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||||
|
"""Download a single image asynchronously"""
|
||||||
|
async with semaphore: # Limit concurrent downloads
|
||||||
|
try:
|
||||||
|
image_url = sample_data["url"]
|
||||||
|
image_path = sample_data["image_path"]
|
||||||
|
|
||||||
|
# Skip if already exists
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
|
||||||
|
async with session.get(image_url, timeout=10) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
|
||||||
|
# Verify it's a valid image
|
||||||
|
try:
|
||||||
|
img = Image.open(io.BytesIO(content))
|
||||||
|
img = img.convert("RGB")
|
||||||
|
img.save(image_path, "JPEG")
|
||||||
|
progress_bar.update(1)
|
||||||
|
return sample_data
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None # Skip invalid images
|
||||||
|
else:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
progress_bar.update(1)
|
||||||
|
return None
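The shared asyncio.Semaphore acquired at the top of each task is what bounds concurrency to 20 in-flight downloads; a stripped-down sketch of the same pattern (independent of this file, with the real aiohttp request replaced by a placeholder):

import asyncio

async def fetch_all(urls, limit=20):
    sem = asyncio.Semaphore(limit)

    async def fetch_one(url):
        async with sem:  # at most `limit` coroutines execute this body concurrently
            await asyncio.sleep(0)  # placeholder for the real aiohttp request
            return url

    return await asyncio.gather(*(fetch_one(u) for u in urls))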
|
||||||
|
|
||||||
|
def download_laion_subset(self, num_samples: int = 1000):
|
||||||
|
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||||
|
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||||
|
|
||||||
|
# Load LAION-400M subset from HuggingFace
|
||||||
|
print("🤗 Loading from HuggingFace datasets...")
|
||||||
|
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||||
|
|
||||||
|
# Collect sample metadata first (fast)
|
||||||
|
print("📋 Collecting sample metadata...")
|
||||||
|
candidates = []
|
||||||
|
for sample in dataset:
|
||||||
|
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||||
|
break
|
||||||
|
|
||||||
|
image_url = sample.get("url", "")
|
||||||
|
caption = sample.get("caption", "")
|
||||||
|
|
||||||
|
if not image_url or not caption:
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||||
|
image_path = self.images_dir / image_filename
|
||||||
|
|
||||||
|
candidate = {
|
||||||
|
"id": f"laion_{len(candidates):06d}",
|
||||||
|
"url": image_url,
|
||||||
|
"caption": caption,
|
||||||
|
"image_path": str(image_path),
|
||||||
|
"width": sample.get("original_width", 512),
|
||||||
|
"height": sample.get("original_height", 512),
|
||||||
|
"similarity": sample.get("similarity", 0.0),
|
||||||
|
}
|
||||||
|
candidates.append(candidate)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download images in parallel
|
||||||
|
async def download_batch():
|
||||||
|
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||||
|
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=30)
|
||||||
|
|
||||||
|
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||||
|
tasks = []
|
||||||
|
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||||
|
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all downloads
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
progress_bar.close()
|
||||||
|
|
||||||
|
# Filter successful downloads
|
||||||
|
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||||
|
return successful[:num_samples]
|
||||||
|
|
||||||
|
# Run async download
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
samples = loop.run_until_complete(download_batch())
|
||||||
|
finally:
|
||||||
|
loop.close()
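# Note: on Python 3.7+ the explicit loop management above is usually written as
#   samples = asyncio.run(download_batch())
# (assuming no event loop is already running in the calling context).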
|
||||||
|
|
||||||
|
# Save metadata
|
||||||
|
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in samples:
|
||||||
|
f.write(json.dumps(sample) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||||
|
"""Generate CLIP image embeddings for downloaded images"""
|
||||||
|
print("🔍 Generating CLIP image embeddings...")
|
||||||
|
|
||||||
|
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||||
|
# This single model can encode both images and text.
|
||||||
|
model = SentenceTransformer("clip-ViT-L-14")
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
valid_samples = []
|
||||||
|
|
||||||
|
for sample in tqdm(samples, desc="Processing images"):
|
||||||
|
try:
|
||||||
|
# Load image
|
||||||
|
image_path = sample["image_path"]
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
|
||||||
|
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||||
|
vec = model.encode(
|
||||||
|
[image],
|
||||||
|
convert_to_numpy=True,
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=1,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)[0]
|
||||||
|
embeddings.append(vec.astype(np.float32))
|
||||||
|
valid_samples.append(sample)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||||
|
# Skip invalid images
|
||||||
|
|
||||||
|
embeddings = np.array(embeddings, dtype=np.float32)
|
||||||
|
|
||||||
|
# Save embeddings
|
||||||
|
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(embeddings_file, embeddings)
|
||||||
|
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||||
|
|
||||||
|
return embeddings, valid_samples
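Because sentence-transformers' clip-ViT-L-14 encodes images and text into the same 768-dim space, caption queries can later be embedded with the same model for text-to-image retrieval; a minimal sketch (model name as above, query text illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("clip-ViT-L-14")
query_vec = model.encode(
    ["a photo of a red bicycle"], convert_to_numpy=True, normalize_embeddings=True
)[0]  # 768-dim unit-norm vector, comparable to the image embeddings by inner product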
|
||||||
|
|
||||||
|
def build_faiss_baseline(
|
||||||
|
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||||
|
):
|
||||||
|
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||||
|
print("🔨 Building FAISS Flat baseline...")
|
||||||
|
|
||||||
|
from leann_backend_hnsw import faiss
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||||
|
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||||
|
|
||||||
|
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||||
|
print(f"✅ Baseline already exists at {baseline_path}")
|
||||||
|
return baseline_path
|
||||||
|
|
||||||
|
# Extract image IDs (must be present)
|
||||||
|
if not samples or "id" not in samples[0]:
|
||||||
|
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||||
|
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||||
|
|
||||||
|
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||||
|
print(f"📄 Processing {len(image_ids)} images")
|
||||||
|
|
||||||
|
# Build FAISS flat index
|
||||||
|
print("🏗️ Building FAISS IndexFlatIP...")
|
||||||
|
dimension = embeddings.shape[1]
|
||||||
|
index = faiss.IndexFlatIP(dimension)
|
||||||
|
|
||||||
|
# Add embeddings to flat index
|
||||||
|
embeddings_f32 = embeddings.astype(np.float32)
|
||||||
|
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||||
|
|
||||||
|
# Save index and metadata
|
||||||
|
faiss.write_index(index, baseline_path)
|
||||||
|
with open(metadata_path, "wb") as f:
|
||||||
|
pickle.dump(image_ids, f)
|
||||||
|
|
||||||
|
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||||
|
print(f"✅ Metadata saved to {metadata_path}")
|
||||||
|
print(f"📊 Total vectors: {index.ntotal}")
|
||||||
|
|
||||||
|
return baseline_path
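Querying this baseline is then an exhaustive inner-product search; a sketch using the standard faiss Python bindings (the leann_backend_hnsw build used above may instead require the swig_ptr-style calls, and the query vector here is a random stand-in for a CLIP text embedding):

import pickle

import faiss
import numpy as np

index = faiss.read_index("baseline/faiss_flat.index")
with open("baseline/metadata.pkl", "rb") as f:
    image_ids = pickle.load(f)

query_vec = np.random.rand(1, index.d).astype(np.float32)  # stand-in for a CLIP text embedding
scores, idx = index.search(query_vec, 3)
top3_ids = [image_ids[j] for j in idx[0]]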
|
||||||
|
|
||||||
|
def create_leann_passages(self, samples: list[dict]):
|
||||||
|
"""Create LEANN-compatible passages from LAION data"""
|
||||||
|
print("📝 Creating LEANN passages...")
|
||||||
|
|
||||||
|
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||||
|
|
||||||
|
with open(passages_file, "w", encoding="utf-8") as f:
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
passage = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"text": sample["caption"], # Use caption as searchable text
|
||||||
|
"metadata": {
|
||||||
|
"image_url": sample["url"],
|
||||||
|
"image_path": sample.get("image_path", ""),
|
||||||
|
"width": sample["width"],
|
||||||
|
"height": sample["height"],
|
||||||
|
"similarity": sample["similarity"],
|
||||||
|
"image_index": i, # Index for embedding lookup
|
||||||
|
},
|
||||||
|
}
|
||||||
|
f.write(json.dumps(passage) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(samples)} passages")
|
||||||
|
return passages_file
|
||||||
|
|
||||||
|
def build_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||||
|
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
|
||||||
|
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - compact with recompute
|
||||||
|
# Note: For multimodal case, we need to handle embeddings differently
|
||||||
|
# Let's try using sentence-transformers mode but with custom embeddings
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
# HNSW params (or forwarded to chosen backend)
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
# Compact/pruned with recompute at query time
|
||||||
|
is_recompute=True,
|
||||||
|
is_compact=True,
|
||||||
|
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages (text + metadata)
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
|
||||||
|
|
||||||
|
def build_non_compact_index(
|
||||||
|
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||||
|
):
|
||||||
|
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||||
|
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure embeddings are saved (npy + pickle)
|
||||||
|
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||||
|
if not npy_path.exists():
|
||||||
|
np.save(npy_path, embeddings)
|
||||||
|
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||||
|
# Prepare ids in same order as passages_file
|
||||||
|
ids: list[str] = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
rec = json.loads(line)
|
||||||
|
ids.append(str(rec["id"]))
|
||||||
|
if len(ids) != embeddings.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||||
|
)
|
||||||
|
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||||
|
if not pkl_path.exists():
|
||||||
|
with open(pkl_path, "wb") as pf:
|
||||||
|
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||||
|
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||||
|
|
||||||
|
# Initialize builder - non-compact without recompute
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name=backend,
|
||||||
|
embedding_model="clip-ViT-L-14",
|
||||||
|
embedding_mode="sentence-transformers",
|
||||||
|
graph_degree=32,
|
||||||
|
complexity=64,
|
||||||
|
is_recompute=False, # Store embeddings (no recompute needed)
|
||||||
|
is_compact=False, # Store full index (not pruned)
|
||||||
|
distance_metric="cosine",
|
||||||
|
num_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add passages - embeddings will be loaded from file
|
||||||
|
print("📚 Adding passages...")
|
||||||
|
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||||
|
|
||||||
|
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||||
|
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||||
|
|
||||||
|
build_time = time.time() - start_time
|
||||||
|
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze index size
|
||||||
|
self._analyze_index_size(index_path)
|
||||||
|
|
||||||
|
return index_path
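At evaluation time both indexes are searched through the same interface (see the searcher.search calls in evaluate_laion.py above); a hedged usage sketch, assuming the leann package exposes a LeannSearcher class alongside LeannBuilder:

from leann import LeannSearcher  # assumed import; only LeannBuilder is imported in this file

searcher = LeannSearcher("data/laion_index_noncompact.leann")
results = searcher.search("a photo of a red bicycle", top_k=3, complexity=64)
for r in results:
    print(r.text)  # the stored caption of each retrieved image
searcher.cleanup()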
|
||||||
|
|
||||||
|
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||||
|
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in tqdm(f, desc="Adding passages"):
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
|
||||||
|
# Add image metadata - LEANN will handle embeddings separately
|
||||||
|
# Note: We store image metadata and caption text for searchability
|
||||||
|
# Important: ensure passage ID in metadata matches vector ID
|
||||||
|
builder.add_text(
|
||||||
|
text=passage["text"], # Image caption for searchability
|
||||||
|
metadata={**passage["metadata"], "id": passage["id"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _analyze_index_size(self, index_path: str):
|
||||||
|
"""Analyze index file sizes"""
|
||||||
|
print("📏 Analyzing index sizes...")
|
||||||
|
|
||||||
|
index_path = Path(index_path)
|
||||||
|
index_dir = index_path.parent
|
||||||
|
index_name = index_path.name # e.g., laion_index.leann
|
||||||
|
index_prefix = index_path.stem # e.g., laion_index
|
||||||
|
|
||||||
|
files = [
|
||||||
|
(f"{index_prefix}.index", ".index", "core"),
|
||||||
|
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||||
|
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||||
|
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||||
|
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def _fmt_size(bytes_val: int) -> str:
|
||||||
|
if bytes_val < 1024:
|
||||||
|
return f"{bytes_val} B"
|
||||||
|
kb = bytes_val / 1024
|
||||||
|
if kb < 1024:
|
||||||
|
return f"{kb:.1f} KB"
|
||||||
|
mb = kb / 1024
|
||||||
|
if mb < 1024:
|
||||||
|
return f"{mb:.2f} MB"
|
||||||
|
gb = mb / 1024
|
||||||
|
return f"{gb:.2f} GB"
|
||||||
|
|
||||||
|
total_index_only_mb = 0.0
|
||||||
|
total_all_mb = 0.0
|
||||||
|
for filename, label, group in files:
|
||||||
|
file_path = index_dir / filename
|
||||||
|
if file_path.exists():
|
||||||
|
size_bytes = file_path.stat().st_size
|
||||||
|
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||||
|
size_mb = size_bytes / (1024 * 1024)
|
||||||
|
total_all_mb += size_mb
|
||||||
|
if group == "core":
|
||||||
|
total_index_only_mb += size_mb
|
||||||
|
else:
|
||||||
|
print(f" {label}: (missing)")
|
||||||
|
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||||
|
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||||
|
|
||||||
|
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||||
|
"""Create evaluation queries from captions"""
|
||||||
|
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||||
|
|
||||||
|
# Sample random captions as queries
|
||||||
|
import random
|
||||||
|
|
||||||
|
random.seed(42) # For reproducibility
|
||||||
|
|
||||||
|
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||||
|
|
||||||
|
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||||
|
with open(queries_file, "w", encoding="utf-8") as f:
|
||||||
|
for sample in query_samples:
|
||||||
|
query = {
|
||||||
|
"id": sample["id"],
|
||||||
|
"query": sample["caption"],
|
||||||
|
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||||
|
}
|
||||||
|
f.write(json.dumps(query) + "\n")
|
||||||
|
|
||||||
|
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||||
|
return queries_file
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||||
|
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||||
|
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||||
|
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||||
|
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||||
|
)
|
||||||
|
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||||
|
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize setup
|
||||||
|
setup = LAIONSetup(args.data_dir)
|
||||||
|
|
||||||
|
# Step 1: Download LAION subset
|
||||||
|
if not args.skip_download:
|
||||||
|
print("\n📦 Step 1: Download LAION subset")
|
||||||
|
samples = setup.download_laion_subset(args.num_samples)
|
||||||
|
|
||||||
|
# Step 2: Generate CLIP image embeddings
|
||||||
|
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||||
|
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||||
|
|
||||||
|
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||||
|
print("\n📝 Step 3: Create LEANN passages")
|
||||||
|
passages_file = setup.create_leann_passages(valid_samples)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping LAION dataset download")
|
||||||
|
# Load existing data
|
||||||
|
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||||
|
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not embeddings_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passages or embeddings file not found. Run without --skip-download first."
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.load(embeddings_file)
|
||||||
|
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||||
|
|
||||||
|
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||||
|
if not args.skip_build:
|
||||||
|
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||||
|
|
||||||
|
# Build compact index (production mode - small, recompute required)
|
||||||
|
compact_index_path = args.index_path
|
||||||
|
print(f"Building compact index: {compact_index_path}")
|
||||||
|
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||||
|
|
||||||
|
# Build non-compact index (comparison mode - large, fast search)
|
||||||
|
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||||
|
print(f"Building non-compact index: {non_compact_index_path}")
|
||||||
|
setup.build_non_compact_index(
|
||||||
|
passages_file, embeddings, non_compact_index_path, args.backend
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Build FAISS flat baseline
|
||||||
|
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||||
|
if not args.skip_download:
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
else:
|
||||||
|
# Load valid_samples from passages file for FAISS baseline
|
||||||
|
valid_samples = []
|
||||||
|
with open(passages_file, encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.strip():
|
||||||
|
passage = json.loads(line)
|
||||||
|
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||||
|
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||||
|
|
||||||
|
# Step 6: Create evaluation queries
|
||||||
|
print("\n📝 Step 6: Create evaluation queries")
|
||||||
|
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||||
|
else:
|
||||||
|
print("⏭️ Skipping index building")
|
||||||
|
baseline_path = "data/baseline/faiss_index.bin"
|
||||||
|
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||||
|
|
||||||
|
print("\n🎉 Setup completed successfully!")
|
||||||
|
print("📊 Summary:")
|
||||||
|
if not args.skip_download:
|
||||||
|
print(f" Downloaded samples: {len(samples)}")
|
||||||
|
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||||
|
else:
|
||||||
|
print(f" Loaded {len(embeddings)} embeddings")
|
||||||
|
|
||||||
|
if not args.skip_build:
|
||||||
|
print(f" Compact index: {compact_index_path}")
|
||||||
|
print(f" Non-compact index: {non_compact_index_path}")
|
||||||
|
print(f" FAISS baseline: {baseline_path}")
|
||||||
|
print(f" Queries: {queries_file}")
|
||||||
|
|
||||||
|
print("\n🔧 Next steps:")
|
||||||
|
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||||
|
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||||
|
else:
|
||||||
|
print(" Skipped building indexes")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n⚠️ Setup interrupted by user")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ Setup failed: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
benchmarks/llm_utils.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
|||||||
|
"""
|
||||||
|
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
HF_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
HF_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
VLLM_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
VLLM_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen3_model(model_name):
|
||||||
|
"""Check if model is Qwen3"""
|
||||||
|
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def is_qwen_vl_model(model_name):
|
||||||
|
"""Check if model is Qwen2.5-VL"""
|
||||||
|
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||||
|
"""Apply Qwen3 chat template with thinking enabled"""
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
return tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_thinking_answer(response):
|
||||||
|
"""Extract final answer from Qwen3 thinking model response"""
|
||||||
|
if "<think>" in response and "</think>" in response:
|
||||||
|
try:
|
||||||
|
think_end = response.index("</think>") + len("</think>")
|
||||||
|
final_answer = response[think_end:].strip()
|
||||||
|
return final_answer
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response.strip()
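Two quick illustrations of the behaviour:

assert extract_thinking_answer("<think>reasoning...</think>\nThe answer is 42.") == "The answer is 42."
assert extract_thinking_answer("plain answer  ") == "plain answer"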
|
||||||
|
|
||||||
|
|
||||||
|
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||||
|
"""Load HuggingFace model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Name of the model to load
|
||||||
|
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||||
|
Defaults to False for security. Only enable for trusted models.
|
||||||
|
"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
if trust_remote_code:
|
||||||
|
print(
|
||||||
|
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loading HF: {model_name}")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
return tokenizer, model
|
||||||
|
|
||||||
|
|
||||||
|
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||||
|
"""Load vLLM model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Name of the model to load
|
||||||
|
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||||
|
Defaults to False for security. Only enable for trusted models.
|
||||||
|
"""
|
||||||
|
if not VLLM_AVAILABLE:
|
||||||
|
raise ImportError("vllm not available")
|
||||||
|
|
||||||
|
if trust_remote_code:
|
||||||
|
print(
|
||||||
|
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loading vLLM: {model_name}")
|
||||||
|
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||||
|
|
||||||
|
# Qwen3 specific config
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||||
|
max_tokens = 2048
|
||||||
|
else:
|
||||||
|
stop_tokens = None
|
||||||
|
max_tokens = 1024
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||||
|
return llm, sampling_params
|
||||||
|
|
||||||
|
|
||||||
|
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||||
|
"""Generate with HF - supports Qwen3 thinking models"""
|
||||||
|
model_name = getattr(model, "name_or_path", "unknown")
|
||||||
|
is_qwen3 = is_qwen3_model(model_name)
|
||||||
|
|
||||||
|
# Apply chat template for Qwen3
|
||||||
|
if is_qwen3:
|
||||||
|
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||||
|
max_tokens = max_tokens or 2048
|
||||||
|
else:
|
||||||
|
max_tokens = max_tokens or 1024
|
||||||
|
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=max_tokens,
|
||||||
|
temperature=0.7,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
response = response[len(prompt) :].strip()
|
||||||
|
|
||||||
|
# Extract final answer for thinking models
|
||||||
|
if is_qwen3:
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def generate_vllm(llm, sampling_params, prompt):
|
||||||
|
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||||
|
outputs = llm.generate([prompt], sampling_params)
|
||||||
|
response = outputs[0].outputs[0].text.strip()
|
||||||
|
|
||||||
|
# Extract final answer for Qwen3 thinking models
|
||||||
|
model_name = str(llm.llm_engine.model_config.model)
|
||||||
|
if is_qwen3_model(model_name):
|
||||||
|
return extract_thinking_answer(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_prompt(context, query, domain="default"):
    """Create RAG prompt"""
    if domain == "emails":
        return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    elif domain == "finance":
        return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    elif domain == "multimodal":
        return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    else:
        return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"


def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
    """Simple RAG evaluation with timing"""
    search_times = []
    gen_times = []
    results = []

    for i, query in enumerate(queries):
        # Search
        start = time.time()
        docs = searcher.search(query, top_k=top_k, complexity=complexity)
        search_time = time.time() - start

        # Generate
        context = "\n\n".join([doc.text for doc in docs])
        prompt = create_prompt(context, query, domain)

        start = time.time()
        response = llm_func(prompt)
        gen_time = time.time() - start

        search_times.append(search_time)
        gen_times.append(gen_time)
        results.append(response)

        if i < 3:
            print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")

    return {
        "avg_search_time": sum(search_times) / len(search_times),
        "avg_generation_time": sum(gen_times) / len(gen_times),
        "results": results,
    }
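A minimal way to wire these helpers together (the searcher object is assumed to come from the caller; everything else is defined in this file):

tokenizer, model = load_hf_model("Qwen/Qwen3-8B")
metrics = evaluate_rag(
    searcher,  # any object exposing .search(query, top_k=..., complexity=...)
    lambda prompt: generate_hf(tokenizer, model, prompt),
    ["What do the retrieved passages say about the topic?"],
    domain="default",
)
print(metrics["avg_search_time"], metrics["avg_generation_time"])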
|
||||||
|
|
||||||
|
|
||||||
|
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||||
|
"""Load Qwen2.5-VL multimodal model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (str): Name of the model to load
|
||||||
|
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||||
|
Defaults to False for security. Only enable for trusted models.
|
||||||
|
"""
|
||||||
|
if not HF_AVAILABLE:
|
||||||
|
raise ImportError("transformers not available")
|
||||||
|
|
||||||
|
if trust_remote_code:
|
||||||
|
print(
|
||||||
|
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||||
|
|
||||||
|
# Fallback to specific class
|
||||||
|
try:
|
||||||
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(
|
||||||
|
model_name, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
return processor, model
|
||||||
|
|
||||||
|
except Exception as e2:
|
||||||
|
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||||
|
"""Generate with Qwen2.5-VL multimodal model"""
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Prepare inputs
|
||||||
|
if image_path:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||||
|
else:
|
||||||
|
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
with torch.no_grad():
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decode response
|
||||||
|
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||||
|
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||||
|
"""Create prompt for multimodal RAG"""
|
||||||
|
if task_type == "images":
|
||||||
|
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||||
|
|
||||||
|
Retrieved Image Descriptions:
|
||||||
|
{context}
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
|
||||||
|
Provide a detailed answer based on the visual content described above."""
|
||||||
|
|
||||||
|
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||||
|
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||||
|
search_times = []
|
||||||
|
gen_times = []
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, query_item in enumerate(queries):
|
||||||
|
# Handle both string and dict formats for queries
|
||||||
|
if isinstance(query_item, dict):
|
||||||
|
query = query_item.get("query", "")
|
||||||
|
image_path = query_item.get("image_path") # Optional reference image
|
||||||
|
else:
|
||||||
|
query = str(query_item)
|
||||||
|
image_path = None
|
||||||
|
|
||||||
|
# Search
|
||||||
|
start_time = time.time()
|
||||||
|
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||||
|
search_time = time.time() - start_time
|
||||||
|
search_times.append(search_time)
|
||||||
|
|
||||||
|
# Prepare context from search results
|
||||||
|
context_parts = []
|
||||||
|
for result in search_results:
|
||||||
|
context_parts.append(f"- {result.text}")
|
||||||
|
context = "\n".join(context_parts)
|
||||||
|
|
||||||
|
# Generate with multimodal model
|
||||||
|
start_time = time.time()
|
||||||
|
if processor and model:
|
||||||
|
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||||
|
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||||
|
else:
|
||||||
|
response = f"Context: {context}"
|
||||||
|
gen_time = time.time() - start_time
|
||||||
|
|
||||||
|
gen_times.append(gen_time)
|
||||||
|
results.append(response)
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"avg_search_time": sum(search_times) / len(search_times),
|
||||||
|
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
@@ -2,20 +2,20 @@
 import argparse
 import time
+from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple

 import numpy as np
 import torch
 from torch import nn
-from transformers import AutoModel, BitsAndBytesConfig
 from tqdm import tqdm
-from contextlib import contextmanager
+from transformers import AutoModel, BitsAndBytesConfig


 @dataclass
 class BenchmarkConfig:
     model_path: str
-    batch_sizes: List[int]
+    batch_sizes: list[int]
     seq_length: int
     num_runs: int
     use_fp16: bool = True
@@ -32,13 +32,11 @@ class GraphContainer:
|
|||||||
def __init__(self, model: nn.Module, seq_length: int):
|
def __init__(self, model: nn.Module, seq_length: int):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.seq_length = seq_length
|
self.seq_length = seq_length
|
||||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
self.graphs: dict[int, GraphWrapper] = {}
|
||||||
|
|
||||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||||
if batch_size not in self.graphs:
|
if batch_size not in self.graphs:
|
||||||
self.graphs[batch_size] = GraphWrapper(
|
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||||
self.model, batch_size, self.seq_length
|
|
||||||
)
|
|
||||||
return self.graphs[batch_size]
|
return self.graphs[batch_size]
|
||||||
|
|
||||||
|
|
||||||
@@ -55,13 +53,13 @@ class GraphWrapper:
|
|||||||
self._warmup()
|
self._warmup()
|
||||||
|
|
||||||
# Only use CUDA graphs on NVIDIA GPUs
|
# Only use CUDA graphs on NVIDIA GPUs
|
||||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||||
# Capture graph
|
# Capture graph
|
||||||
self.graph = torch.cuda.CUDAGraph()
|
self.graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(self.graph):
|
with torch.cuda.graph(self.graph):
|
||||||
self.static_output = self.model(
|
self.static_output = self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
self.use_cuda_graph = True
|
self.use_cuda_graph = True
|
||||||
else:
|
else:
|
||||||
@@ -79,9 +77,7 @@ class GraphWrapper:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000, (batch_size, seq_length),
|
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||||
device=self.device,
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _warmup(self, num_warmup: int = 3):
|
def _warmup(self, num_warmup: int = 3):
|
||||||
@@ -89,7 +85,7 @@ class GraphWrapper:
|
|||||||
for _ in range(num_warmup):
|
for _ in range(num_warmup):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids=self.static_input,
|
input_ids=self.static_input,
|
||||||
attention_mask=self.static_attention_mask
|
attention_mask=self.static_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -133,8 +129,12 @@ class ModelOptimizer:
|
|||||||
print("- Using FP16 precision")
|
print("- Using FP16 precision")
|
||||||
|
|
||||||
# Check if using SDPA (only on CUDA)
|
# Check if using SDPA (only on CUDA)
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -142,7 +142,8 @@ class ModelOptimizer:
|
|||||||
# Flash Attention (only on CUDA)
|
# Flash Attention (only on CUDA)
|
||||||
if config.use_flash_attention and torch.cuda.is_available():
|
if config.use_flash_attention and torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||||
|
|
||||||
print("- Flash Attention 2 available")
|
print("- Flash Attention 2 available")
|
||||||
if hasattr(model.config, "attention_mode"):
|
if hasattr(model.config, "attention_mode"):
|
||||||
model.config.attention_mode = "flash_attention_2"
|
model.config.attention_mode = "flash_attention_2"
|
||||||
@@ -153,8 +154,9 @@ class ModelOptimizer:
|
|||||||
# Memory efficient attention (only on CUDA)
|
# Memory efficient attention (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -220,7 +222,7 @@ class Benchmark:
|
|||||||
self.graphs = None
|
self.graphs = None
|
||||||
self.timer = Timer()
|
self.timer = Timer()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
print(f"ERROR in benchmark initialization: {e!s}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
@@ -230,15 +232,17 @@ class Benchmark:
|
|||||||
# Int4 quantization using HuggingFace integration
|
# Int4 quantization using HuggingFace integration
|
||||||
if self.config.use_int4:
|
if self.config.use_int4:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||||
|
|
||||||
# 检查是否使用自定义的8bit量化
|
# Check if using custom 8bit quantization
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||||
|
|
||||||
# 加载原始模型(不使用量化配置)
|
# Load original model (without quantization config)
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# set default to half
|
# set default to half
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||||
@@ -247,52 +251,58 @@ class Benchmark:
|
|||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 定义替换函数
|
# Define replacement function
|
||||||
def replace_linear_with_linear8bitlt(model):
|
def replace_linear_with_linear8bitlt(model):
|
||||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||||
for name, module in list(model.named_children()):
|
for name, module in list(model.named_children()):
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
# 获取原始线性层的参数
|
# Get original linear layer parameters
|
||||||
in_features = module.in_features
|
in_features = module.in_features
|
||||||
out_features = module.out_features
|
out_features = module.out_features
|
||||||
bias = module.bias is not None
|
bias = module.bias is not None
|
||||||
|
|
||||||
# 创建8bit线性层
|
# Create 8bit linear layer
|
||||||
# print size
|
# print size
|
||||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||||
new_module = bnb.nn.Linear8bitLt(
|
new_module = bnb.nn.Linear8bitLt(
|
||||||
in_features,
|
in_features,
|
||||||
out_features,
|
out_features,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
has_fp16_weights=False
|
has_fp16_weights=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 复制权重和偏置
|
# Copy weights and bias
|
||||||
new_module.weight.data = module.weight.data
|
new_module.weight.data = module.weight.data
|
||||||
if bias:
|
if bias:
|
||||||
new_module.bias.data = module.bias.data
|
new_module.bias.data = module.bias.data
|
||||||
|
|
||||||
# 替换模块
|
# Replace module
|
||||||
setattr(model, name, new_module)
|
setattr(model, name, new_module)
|
||||||
else:
|
else:
|
||||||
# 递归处理子模块
|
# Process child modules recursively
|
||||||
replace_linear_with_linear8bitlt(module)
|
replace_linear_with_linear8bitlt(module)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
# 替换所有线性层
|
# Replace all linear layers
|
||||||
model = replace_linear_with_linear8bitlt(model)
|
model = replace_linear_with_linear8bitlt(model)
|
||||||
# add torch compile
|
# add torch compile
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# 将模型移到GPU(量化发生在这里)
|
# Move model to GPU (quantization happens here)
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
print("- All linear layers replaced with Linear8bitLt")
|
print("- All linear layers replaced with Linear8bitLt")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 使用原来的Int4量化方法
|
# Use original Int4 quantization method
|
||||||
print("- Using bitsandbytes for Int4 quantization")
|
print("- Using bitsandbytes for Int4 quantization")
|
||||||
|
|
||||||
# Create quantization config
|
# Create quantization config
|
||||||
@@ -302,7 +312,7 @@ class Benchmark:
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
bnb_4bit_compute_dtype=compute_dtype,
|
bnb_4bit_compute_dtype=compute_dtype,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4"
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
|
|
||||||
print("- Quantization config:", quantization_config)
|
print("- Quantization config:", quantization_config)
|
||||||
@@ -312,7 +322,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto" # Let HF decide on device mapping
|
device_map="auto", # Let HF decide on device mapping
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if model loaded successfully
|
# Check if model loaded successfully
|
||||||
@@ -324,7 +334,7 @@ class Benchmark:
|
|||||||
# Apply optimizations directly here
|
# Apply optimizations directly here
|
||||||
print("\nApplying model optimizations:")
|
print("\nApplying model optimizations:")
|
||||||
|
|
||||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||||
else:
|
else:
|
||||||
# Skip moving to GPU since device_map="auto" already did that
|
# Skip moving to GPU since device_map="auto" already did that
|
||||||
@@ -334,8 +344,12 @@ class Benchmark:
|
|||||||
print(f"- Using {compute_dtype} for compute dtype")
|
print(f"- Using {compute_dtype} for compute dtype")
|
||||||
|
|
||||||
# Check CUDA and SDPA
|
# Check CUDA and SDPA
|
||||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
if (
|
||||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
torch.cuda.is_available()
|
||||||
|
and torch.version.cuda
|
||||||
|
and float(torch.version.cuda[:3]) >= 11.6
|
||||||
|
):
|
||||||
|
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||||
else:
|
else:
|
||||||
print("- PyTorch SDPA not available")
|
print("- PyTorch SDPA not available")
|
||||||
@@ -343,8 +357,7 @@ class Benchmark:
|
|||||||
# Try xformers if available (only on CUDA)
|
# Try xformers if available (only on CUDA)
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
try:
|
try:
|
||||||
from xformers.ops import memory_efficient_attention
|
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
|
||||||
model.enable_xformers_memory_efficient_attention()
|
model.enable_xformers_memory_efficient_attention()
|
||||||
print("- Enabled xformers memory efficient attention")
|
print("- Enabled xformers memory efficient attention")
|
||||||
else:
|
else:
|
||||||
@@ -370,7 +383,7 @@ class Benchmark:
|
|||||||
self.config.model_path,
|
self.config.model_path,
|
||||||
quantization_config=quantization_config,
|
quantization_config=quantization_config,
|
||||||
torch_dtype=compute_dtype,
|
torch_dtype=compute_dtype,
|
||||||
device_map="auto"
|
device_map="auto",
|
||||||
)
|
)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -389,6 +402,7 @@ class Benchmark:
|
|||||||
# Apply standard optimizations
|
# Apply standard optimizations
|
||||||
# set default to half
|
# set default to half
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
model = ModelOptimizer.optimize(model, self.config)
|
model = ModelOptimizer.optimize(model, self.config)
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -403,25 +417,31 @@ class Benchmark:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR loading model: {str(e)}")
|
print(f"ERROR loading model: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=device,
|
device=device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(
|
def _run_inference(
|
||||||
self,
|
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||||
input_ids: torch.Tensor,
|
) -> tuple[float, torch.Tensor]:
|
||||||
graph_wrapper: Optional[GraphWrapper] = None
|
|
||||||
) -> Tuple[float, torch.Tensor]:
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
|
||||||
with torch.no_grad(), self.timer.timing():
|
with torch.no_grad(), self.timer.timing():
|
||||||
@@ -432,7 +452,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return self.timer.elapsed_time(), output
|
return self.timer.elapsed_time(), output
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
# Reset peak memory stats
|
# Reset peak memory stats
|
||||||
@@ -450,9 +470,7 @@ class Benchmark:
|
|||||||
|
|
||||||
# Get or create graph for this batch size
|
# Get or create graph for this batch size
|
||||||
graph_wrapper = (
|
graph_wrapper = (
|
||||||
self.graphs.get_or_create(batch_size)
|
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||||
if self.graphs is not None
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pre-allocate input tensor
|
# Pre-allocate input tensor
|
||||||
@@ -490,7 +508,7 @@ class Benchmark:
|
|||||||
|
|
||||||
# Log memory usage
|
# Log memory usage
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
# MPS doesn't have max_memory_allocated, use 0
|
# MPS doesn't have max_memory_allocated, use 0
|
||||||
peak_memory_gb = 0.0
|
peak_memory_gb = 0.0
|
||||||
@@ -604,7 +622,15 @@ def main():
|
|||||||
os.makedirs("results", exist_ok=True)
|
os.makedirs("results", exist_ok=True)
|
||||||
|
|
||||||
# Generate filename based on configuration
|
# Generate filename based on configuration
|
||||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
precision_type = (
|
||||||
|
"int4"
|
||||||
|
if config.use_int4
|
||||||
|
else "int8"
|
||||||
|
if config.use_int8
|
||||||
|
else "fp16"
|
||||||
|
if config.use_fp16
|
||||||
|
else "fp32"
|
||||||
|
)
|
||||||
model_name = os.path.basename(config.model_path)
|
model_name = os.path.basename(config.model_path)
|
||||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||||
|
|
||||||
@@ -612,17 +638,20 @@ def main():
|
|||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
"config": {
|
||||||
"results": {str(k): v for k, v in results.items()}
|
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||||
|
},
|
||||||
|
"results": {str(k): v for k, v in results.items()},
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
indent=2
|
indent=2,
|
||||||
)
|
)
|
||||||
print(f"Results saved to {output_file}")
|
print(f"Results saved to {output_file}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
|||||||
results and the golden standard results, making the comparison robust to ID changes.
|
results and the golden standard results, making the comparison robust to ID changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from leann.api import LeannSearcher, LeannBuilder
|
import numpy as np
|
||||||
|
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||||
|
|
||||||
|
|
||||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||||
if not data_root.exists():
|
if not data_root.exists():
|
||||||
print(f"Data directory '{data_root}' not found.")
|
print(f"Data directory '{data_root}' not found.")
|
||||||
print(
|
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
@@ -56,14 +53,14 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
|||||||
print(
|
print(
|
||||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||||
)
|
)
|
||||||
print("uv pip install -e '.[dev]'")
|
print("uv sync --only-group dev")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred during data download: {e}")
|
print(f"An error occurred during data download: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||||
"""Download embeddings files specifically."""
|
"""Download embeddings files specifically."""
|
||||||
embeddings_dir = data_root / "embeddings"
|
embeddings_dir = data_root / "embeddings"
|
||||||
|
|
||||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
|||||||
|
|
||||||
|
|
||||||
# --- Helper Function to get Golden Passages ---
|
# --- Helper Function to get Golden Passages ---
|
||||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||||
"""
|
"""
|
||||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||||
passage manager.
|
passage manager.
|
||||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
|||||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||||
golden_texts.add(passage_data["text"])
|
golden_texts.add(passage_data["text"])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print(
|
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
|
||||||
)
|
|
||||||
return golden_texts
|
return golden_texts
|
||||||
|
|
||||||
|
|
||||||
def load_queries(file_path: Path) -> List[str]:
|
def load_queries(file_path: Path) -> list[str]:
|
||||||
queries = []
|
queries = []
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, encoding="utf-8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
data = json.loads(line)
|
data = json.loads(line)
|
||||||
queries.append(data["query"])
|
queries.append(data["query"])
|
||||||
return queries
|
return queries
|
||||||
|
|
||||||
|
|
||||||
def build_index_from_embeddings(
|
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Build a LEANN index from pre-computed embeddings.
|
Build a LEANN index from pre-computed embeddings.
|
||||||
|
|
||||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||||
description="Run recall evaluation on a LEANN index."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"index_path",
|
"index_path",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -202,26 +193,41 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Batch size for HNSW batched search (0 disables batching)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-type",
|
||||||
|
type=str,
|
||||||
|
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||||
|
default="ollama",
|
||||||
|
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--llm-model",
|
||||||
|
type=str,
|
||||||
|
default="qwen3:1.7b",
|
||||||
|
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# --- Path Configuration ---
|
# --- Path Configuration ---
|
||||||
# Assumes a project structure where the script is in 'examples/'
|
# Assumes a project structure where the script is in 'benchmarks/'
|
||||||
# and data is in 'data/' at the project root.
|
# and evaluation data is in 'benchmarks/data/'.
|
||||||
project_root = Path(__file__).resolve().parent.parent
|
script_dir = Path(__file__).resolve().parent
|
||||||
data_root = project_root / "data"
|
data_root = script_dir / "data"
|
||||||
|
|
||||||
# Download data based on mode
|
# Download data based on mode
|
||||||
if args.mode == "build":
|
if args.mode == "build":
|
||||||
# For building mode, we need embeddings
|
# For building mode, we need embeddings
|
||||||
download_data_if_needed(
|
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||||
data_root, download_embeddings=False
|
|
||||||
) # Basic data first
|
|
||||||
|
|
||||||
# Auto-detect dataset type and download embeddings
|
# Auto-detect dataset type and download embeddings
|
||||||
if args.embeddings_file:
|
if args.embeddings_file:
|
||||||
@@ -262,9 +268,7 @@ def main():
|
|||||||
print(f"Index built successfully: {built_index_path}")
|
print(f"Index built successfully: {built_index_path}")
|
||||||
|
|
||||||
# Ask if user wants to run evaluation
|
# Ask if user wants to run evaluation
|
||||||
eval_response = (
|
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
|
||||||
)
|
|
||||||
if eval_response != "y":
|
if eval_response != "y":
|
||||||
print("Index building complete. Exiting.")
|
print("Index building complete. Exiting.")
|
||||||
return
|
return
|
||||||
@@ -293,11 +297,9 @@ def main():
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not args.index_path:
|
if not args.index_path:
|
||||||
|
print("No indices found. The data download should have included pre-built indices.")
|
||||||
print(
|
print(
|
||||||
"No indices found. The data download should have included pre-built indices."
|
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||||
)
|
|
||||||
print(
|
|
||||||
"Please check the data/indices/ directory or provide --index-path manually."
|
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@@ -310,14 +312,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
# Fallback: try to infer from the index directory name
|
# Fallback: try to infer from the index directory name
|
||||||
dataset_type = Path(args.index_path).name
|
dataset_type = Path(args.index_path).name
|
||||||
print(
|
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||||
golden_results_file = (
|
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||||
print(f"INFO: Using queries file: {queries_file}")
|
print(f"INFO: Using queries file: {queries_file}")
|
||||||
@@ -327,7 +325,7 @@ def main():
|
|||||||
searcher = LeannSearcher(args.index_path)
|
searcher = LeannSearcher(args.index_path)
|
||||||
queries = load_queries(queries_file)
|
queries = load_queries(queries_file)
|
||||||
|
|
||||||
with open(golden_results_file, "r") as f:
|
with open(golden_results_file) as f:
|
||||||
golden_results_data = json.load(f)
|
golden_results_data = json.load(f)
|
||||||
|
|
||||||
num_eval_queries = min(args.num_queries, len(queries))
|
num_eval_queries = min(args.num_queries, len(queries))
|
||||||
@@ -340,10 +338,23 @@ def main():
|
|||||||
for i in range(num_eval_queries):
|
for i in range(num_eval_queries):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
new_results = searcher.search(
|
new_results = searcher.search(
|
||||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
)
|
)
|
||||||
search_times.append(time.time() - start_time)
|
search_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||||
|
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||||
|
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||||
|
answer = chat.ask(
|
||||||
|
queries[i],
|
||||||
|
top_k=args.top_k,
|
||||||
|
complexity=args.ef_search,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
)
|
||||||
|
print(f"Answer: {answer}")
|
||||||
# Correct Recall Calculation: Based on TEXT content
|
# Correct Recall Calculation: Based on TEXT content
|
||||||
new_texts = {result.text for result in new_results}
|
new_texts = {result.text for result in new_results}
|
||||||
|
|
||||||
@@ -1,26 +1,27 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, BitsAndBytesConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoModel
|
||||||
|
|
||||||
# Add MLX imports
|
# Add MLX imports
|
||||||
try:
|
try:
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx_lm.utils import load
|
from mlx_lm.utils import load
|
||||||
|
|
||||||
MLX_AVAILABLE = True
|
MLX_AVAILABLE = True
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||||
MLX_AVAILABLE = False
|
MLX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BenchmarkConfig:
|
class BenchmarkConfig:
|
||||||
model_path: str = "facebook/contriever"
|
model_path: str = "facebook/contriever-msmarco"
|
||||||
batch_sizes: List[int] = None
|
batch_sizes: list[int] = None
|
||||||
seq_length: int = 256
|
seq_length: int = 256
|
||||||
num_runs: int = 5
|
num_runs: int = 5
|
||||||
use_fp16: bool = True
|
use_fp16: bool = True
|
||||||
@@ -33,7 +34,8 @@ class BenchmarkConfig:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.batch_sizes is None:
|
if self.batch_sizes is None:
|
||||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
|
||||||
|
|
||||||
class MLXBenchmark:
|
class MLXBenchmark:
|
||||||
"""MLX-specific benchmark for embedding models"""
|
"""MLX-specific benchmark for embedding models"""
|
||||||
@@ -55,11 +57,7 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int):
|
def _create_random_batch(self, batch_size: int):
|
||||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||||
return torch.randint(
|
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||||
0, 1000,
|
|
||||||
(batch_size, self.config.seq_length),
|
|
||||||
dtype=torch.long
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
"""Run MLX inference with same input as PyTorch"""
|
"""Run MLX inference with same input as PyTorch"""
|
||||||
@@ -82,12 +80,12 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX inference error: {e}")
|
print(f"MLX inference error: {e}")
|
||||||
return float('inf')
|
return float("inf")
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
"""Run the MLX benchmark across all batch sizes"""
|
"""Run the MLX benchmark across all batch sizes"""
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
@@ -111,10 +109,10 @@ class MLXBenchmark:
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Run benchmark
|
# Run benchmark
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
if elapsed_time != float('inf'):
|
if elapsed_time != float("inf"):
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error during MLX inference: {e}")
|
print(f"Error during MLX inference: {e}")
|
||||||
@@ -145,16 +143,22 @@ class MLXBenchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, config: BenchmarkConfig):
|
def __init__(self, config: BenchmarkConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
self.device = (
|
||||||
|
"cuda"
|
||||||
|
if torch.cuda.is_available()
|
||||||
|
else "mps"
|
||||||
|
if torch.backends.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
self.model = self._load_model()
|
self.model = self._load_model()
|
||||||
|
|
||||||
def _load_model(self) -> nn.Module:
|
def _load_model(self) -> nn.Module:
|
||||||
print(f"Loading model from {self.config.model_path}...")
|
print(f"Loading model from {self.config.model_path}...")
|
||||||
|
|
||||||
|
|
||||||
model = AutoModel.from_pretrained(self.config.model_path)
|
model = AutoModel.from_pretrained(self.config.model_path)
|
||||||
if self.config.use_fp16:
|
if self.config.use_fp16:
|
||||||
model = model.half()
|
model = model.half()
|
||||||
@@ -166,23 +170,30 @@ class Benchmark:
|
|||||||
|
|
||||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||||
return torch.randint(
|
return torch.randint(
|
||||||
0, 1000,
|
0,
|
||||||
|
1000,
|
||||||
(batch_size, self.config.seq_length),
|
(batch_size, self.config.seq_length),
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
# print shape of input_ids and attention_mask
|
||||||
|
print(f"input_ids shape: {input_ids.shape}")
|
||||||
|
print(f"attention_mask shape: {attention_mask.shape}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
torch.mps.synchronize()
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
return end_time - start_time
|
return end_time - start_time
|
||||||
|
|
||||||
def run(self) -> Dict[int, Dict[str, float]]:
|
def run(self) -> dict[int, dict[str, float]]:
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@@ -194,7 +205,7 @@ class Benchmark:
|
|||||||
|
|
||||||
input_ids = self._create_random_batch(batch_size)
|
input_ids = self._create_random_batch(batch_size)
|
||||||
|
|
||||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||||
try:
|
try:
|
||||||
elapsed_time = self._run_inference(input_ids)
|
elapsed_time = self._run_inference(input_ids)
|
||||||
times.append(elapsed_time)
|
times.append(elapsed_time)
|
||||||
@@ -219,7 +230,7 @@ class Benchmark:
|
|||||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||||
else:
|
else:
|
||||||
peak_memory_gb = 0.0
|
peak_memory_gb = 0.0
|
||||||
|
|
||||||
@@ -228,6 +239,7 @@ class Benchmark:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark():
|
def run_benchmark():
|
||||||
"""Main function to run the benchmark with optimized parameters."""
|
"""Main function to run the benchmark with optimized parameters."""
|
||||||
config = BenchmarkConfig()
|
config = BenchmarkConfig()
|
||||||
@@ -242,16 +254,13 @@ def run_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Benchmark failed: {e}")
|
print(f"Benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_mlx_benchmark():
|
def run_mlx_benchmark():
|
||||||
"""Run MLX-specific benchmark"""
|
"""Run MLX-specific benchmark"""
|
||||||
@@ -260,13 +269,10 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "MLX not available"
|
"error": "MLX not available",
|
||||||
}
|
}
|
||||||
|
|
||||||
config = BenchmarkConfig(
|
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
|
||||||
use_mlx=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
benchmark = MLXBenchmark(config)
|
benchmark = MLXBenchmark(config)
|
||||||
@@ -276,7 +282,7 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": 0.0,
|
"max_throughput": 0.0,
|
||||||
"avg_throughput": 0.0,
|
"avg_throughput": 0.0,
|
||||||
"error": "No valid results"
|
"error": "No valid results",
|
||||||
}
|
}
|
||||||
|
|
||||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||||
@@ -285,16 +291,13 @@ def run_mlx_benchmark():
|
|||||||
return {
|
return {
|
||||||
"max_throughput": max_throughput,
|
"max_throughput": max_throughput,
|
||||||
"avg_throughput": avg_throughput,
|
"avg_throughput": avg_throughput,
|
||||||
"results": results
|
"results": results,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"MLX benchmark failed: {e}")
|
print(f"MLX benchmark failed: {e}")
|
||||||
return {
|
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||||
"max_throughput": 0.0,
|
|
||||||
"avg_throughput": 0.0,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("=== PyTorch Benchmark ===")
|
print("=== PyTorch Benchmark ===")
|
||||||
@@ -308,7 +311,7 @@ if __name__ == "__main__":
|
|||||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||||
|
|
||||||
# Compare results
|
# Compare results
|
||||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||||
print(f"\n=== Comparison ===")
|
print("\n=== Comparison ===")
|
||||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||||
benchmarks/update/README.md (new file, +143 lines)
@@ -0,0 +1,143 @@
|
|||||||
|
# Update Benchmarks
|
||||||
|
|
||||||
|
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||||
|
search” pipeline under different assumptions:
|
||||||
|
|
||||||
|
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||||
|
settings influence incremental `add()` latency when embeddings are fetched
|
||||||
|
over ZMQ from the embedding server.
|
||||||
|
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||||
|
against an offline approach that keeps the graph static and fuses results.
|
||||||
|
|
||||||
|
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||||
|
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||||
|
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||||
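For orientation, the index construction used by both suites boils down to a `LeannBuilder` call like the minimal sketch below; the embedding model and the sample passages here are placeholders, and the full version lives in `bench_hnsw_rng_recompute.py`:

```python
from leann.api import LeannBuilder

# Minimal sketch of the non-compact, recompute-enabled HNSW index both suites build.
# The model name and the two sample passages are placeholders.
builder = LeannBuilder(
    backend_name="hnsw",
    embedding_model="facebook/contriever",  # placeholder embedding model
    is_compact=False,     # keep the graph non-compact so it can be updated in place
    is_recompute=True,    # embeddings are recomputed via the embedding server
    backend_kwargs={"is_compact": False, "is_recompute": True},
)
for idx, passage in enumerate(["First sample passage.", "Second sample passage."]):
    builder.add_text(passage, metadata={"id": str(idx)})
builder.build_index(".leann/bench/test.leann")
```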
|
|
||||||
|
## Benchmarks
|
||||||
|
|
||||||
|
### 1. HNSW RNG Recompute Benchmark
|
||||||
|
|
||||||
|
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||||
|
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||||
|
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||||
|
is enabled:
|
||||||
|
|
||||||
|
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||||
|
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||||
|
| `baseline` | Enabled | Enabled | Enabled |
|
||||||
|
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||||
|
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||||
|
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||||
|
|
||||||
|
For each scenario the script:
|
||||||
|
1. (Re)builds an `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||||
|
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||||
|
3. Appends the requested updates using the scenario’s RNG flags (the toggle calls are sketched after this list).
|
||||||
|
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||||
|
timings before appending a row to the CSV output.
|
||||||
|
|
||||||
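Under the hood, step 3 flips the pruning flags on the loaded HNSW index before appending. A condensed sketch of that toggle is shown below; the index path and flag values are placeholders, and the setters come from LEANN's bundled faiss backend as used in `benchmark_update_with_mode`:

```python
from leann_backend_hnsw import faiss  # LEANN's bundled faiss build exposes the RNG toggles

# Placeholder path; the benchmark derives the real one from --index-path.
index = faiss.read_index(".leann/bench/test.index")
try:
    index.hnsw.set_disable_rng_during_add(True)   # e.g. the disable_forward_rng scenario
    index.hnsw.set_disable_reverse_prune(True)    # e.g. the disable_forward_and_reverse_rng scenario
except AttributeError:
    # Older backend builds may not expose the toggles; the script then keeps the defaults.
    pass
```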
|
**Run:**
|
||||||
|
```bash
|
||||||
|
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||||
|
LEANN_LOG_LEVEL=INFO \
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--runs 1 \
|
||||||
|
--index-path .leann/bench/test.leann \
|
||||||
|
--initial-files data/PrideandPrejudice.txt \
|
||||||
|
--update-files data/huawei_pangu.md \
|
||||||
|
--max-initial 300 \
|
||||||
|
--max-updates 1 \
|
||||||
|
--add-timeout 120
|
||||||
|
```
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||||
|
(including ms/passage) for each run.
|
||||||
|
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||||
|
`LEANN_HNSW_LOG_PATH`).
|
||||||
|
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||||
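For a quick look at the recorded rows without generating a plot, a small stdlib-only snippet like the following works; it makes no assumptions about the exact column names and simply prints each appended row:

```python
import csv
from pathlib import Path

# Print every row the RNG benchmark appended to its CSV output.
csv_path = Path("benchmarks/update/bench_results.csv")
with csv_path.open(newline="") as f:
    for row in csv.DictReader(f):
        print(row)
```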
|
|
||||||
|
### 2. Sequential vs. Offline Update Benchmark
|
||||||
|
|
||||||
|
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||||
|
same dataset:
|
||||||
|
|
||||||
|
- **Scenario A – Sequential Update**
|
||||||
|
- Start an embedding server.
|
||||||
|
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||||
|
mutates the HNSW graph.
|
||||||
|
- After all inserts, run a search on the updated graph.
|
||||||
|
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||||
|
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||||
|
latency.
|
||||||
|
|
||||||
|
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||||
|
- Stop Scenario A’s server and start a fresh embedding server.
|
||||||
|
- Spawn two threads: one generates embeddings for the new passages offline
|
||||||
|
(graph unchanged); the other computes the query embedding and searches the
|
||||||
|
existing graph.
|
||||||
|
- Merge offline similarities with the graph search results to emulate late
|
||||||
|
fusion, then report the merged top‑k preview (a minimal fusion sketch follows this list).
|
||||||
|
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||||
|
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||||
|
|
||||||
|
**Run (both scenarios):**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 \
|
||||||
|
--num-updates 1
|
||||||
|
```
|
||||||
|
|
||||||
|
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||||
|
print timing summaries to stdout and append the results to CSV.
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||||
|
Scenario A and B.
|
||||||
|
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||||
|
checks.
|
||||||
|
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||||
|
|
||||||
|
### 3. Visualisation
|
||||||
|
|
||||||
|
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||||
|
benchmark into a single two-panel plot.
|
||||||
|
|
||||||
|
**Run:**
|
||||||
|
```bash
|
||||||
|
uv run -m benchmarks.update.plot_bench_results \
|
||||||
|
--csv benchmarks/update/bench_results.csv \
|
||||||
|
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||||
|
--out benchmarks/update/bench_latency_from_csv.png
|
||||||
|
```
|
||||||
|
|
||||||
|
**Options:**
|
||||||
|
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||||
|
- `--csv` – RNG benchmark results CSV (left panel).
|
||||||
|
- `--csv-right` – Update strategy results CSV (right panel).
|
||||||
|
- `--out` – Output image path (PNG/PDF supported).
|
||||||
|
|
||||||
|
**Output:**
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||||
|
suites.
|
||||||
|
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||||
|
slides/papers.
|
||||||
|
|
||||||
|
## Parameters & Environment
|
||||||
|
|
||||||
|
### Common CLI Flags
|
||||||
|
- `--max-initial` – Number of initial passages used to seed the index.
|
||||||
|
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||||
|
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||||
|
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||||
|
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||||
|
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||||
|
execution of the embedding model.
|
||||||
|
|
||||||
|
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||||
|
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||||
|
fusion better match your latency/accuracy trade-offs.
|
||||||
benchmarks/update/__init__.py (new file, +16 lines)
@@ -0,0 +1,16 @@
|
|||||||
|
"""Benchmarks for LEANN update workflows."""
|
||||||
|
|
||||||
|
# Expose helper to locate repository root for other modules that need it.
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def find_repo_root() -> Path:
|
||||||
|
"""Return the project root containing pyproject.toml."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
return current.parents[1]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["find_repo_root"]
|
||||||
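A possible use of the helper from another benchmark module (illustrative):

```python
from benchmarks.update import find_repo_root

# Resolve benchmark inputs relative to the repository root instead of the CWD.
repo_root = find_repo_root()
print(repo_root / "data" / "huawei_pangu.md")
```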
benchmarks/update/bench_hnsw_rng_recompute.py (new file, +804 lines)
@@ -0,0 +1,804 @@
|
|||||||
|
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||||
|
embedding recomputation.
|
||||||
|
|
||||||
|
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||||
|
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||||
|
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||||
|
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||||
|
|
||||||
|
Example usage (run from the repo root; downloads the model on first run)::
|
||||||
|
|
||||||
|
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||||
|
--index-path .leann/bench/leann-demo.leann \
|
||||||
|
--runs 1
|
||||||
|
|
||||||
|
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||||
|
if you want a larger or different workload, and change the embedding model via
|
||||||
|
``--model-name``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import msgpack
|
||||||
|
import numpy as np
|
||||||
|
import zmq
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||||
|
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_update_with_mode(
|
||||||
|
index_path: Path,
|
||||||
|
new_chunks: list[dict[str, Any]],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
disable_forward_rng: bool,
|
||||||
|
disable_reverse_rng: bool,
|
||||||
|
server_port: int,
|
||||||
|
add_timeout: int,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||||
|
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||||
|
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
|
||||||
|
with open(offset_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
existing_ids = set(offset_map.keys())
|
||||||
|
|
||||||
|
valid_chunks: list[dict[str, Any]] = []
|
||||||
|
    for chunk in new_chunks:
        text = chunk.get("text", "")
        if not isinstance(text, str) or not text.strip():
            continue
        metadata = chunk.setdefault("metadata", {})
        passage_id = chunk.get("id") or metadata.get("id")
        if passage_id and passage_id in existing_ids:
            raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
        valid_chunks.append(chunk)

    if not valid_chunks:
        raise ValueError("No valid chunks to append.")

    texts_to_embed = [chunk["text"] for chunk in valid_chunks]
    embeddings = compute_embeddings(
        texts_to_embed,
        model_name,
        mode=embedding_mode,
        is_build=False,
        batch_size=16,
    )

    embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
    if distance_metric == "cosine":
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1
        embeddings = embeddings / norms

    index = faiss.read_index(str(index_file))
    index.is_recompute = True
    if getattr(index, "storage", None) is None:
        if index.metric_type == faiss.METRIC_INNER_PRODUCT:
            storage_index = faiss.IndexFlatIP(index.d)
        else:
            storage_index = faiss.IndexFlatL2(index.d)
        index.storage = storage_index
        index.own_fields = True
        try:
            storage_index.ntotal = index.ntotal
        except AttributeError:
            pass
    try:
        index.hnsw.set_disable_rng_during_add(disable_forward_rng)
        index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
        if ef_construction is not None:
            index.hnsw.efConstruction = ef_construction
    except AttributeError:
        pass

    applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
    applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
    logger.info(
        "HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
        disable_forward_rng,
        disable_reverse_rng,
        applied_forward,
        applied_reverse,
    )

    base_id = index.ntotal
    for offset, chunk in enumerate(valid_chunks):
        new_id = str(base_id + offset)
        chunk.setdefault("metadata", {})["id"] = new_id
        chunk["id"] = new_id

    rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
    offset_map_backup = offset_map.copy()

    try:
        with open(passages_file, "a", encoding="utf-8") as f:
            for chunk in valid_chunks:
                offset = f.tell()
                json.dump(
                    {
                        "id": chunk["id"],
                        "text": chunk["text"],
                        "metadata": chunk.get("metadata", {}),
                    },
                    f,
                    ensure_ascii=False,
                )
                f.write("\n")
                offset_map[chunk["id"]] = offset

        with open(offset_file, "wb") as f:
            pickle.dump(offset_map, f)

        server_manager = EmbeddingServerManager(
            backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
        )
        server_started, actual_port = server_manager.start_server(
            port=server_port,
            model_name=model_name,
            embedding_mode=embedding_mode,
            passages_file=str(meta_path),
            distance_metric=distance_metric,
        )
        if not server_started:
            raise RuntimeError("Failed to start embedding server.")

        if hasattr(index.hnsw, "set_zmq_port"):
            index.hnsw.set_zmq_port(actual_port)
        elif hasattr(index, "set_zmq_port"):
            index.set_zmq_port(actual_port)

        _warmup_embedding_server(actual_port)

        total_start = time.time()
        add_elapsed = 0.0

        try:
            import signal

            def _timeout_handler(signum, frame):
                raise TimeoutError("incremental add timed out")

            if add_timeout > 0:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(add_timeout)

            add_start = time.time()
            for i in range(embeddings.shape[0]):
                index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
            add_elapsed = time.time() - add_start
            if add_timeout > 0:
                signal.alarm(0)
            faiss.write_index(index, str(index_file))
        finally:
            server_manager.stop_server()

    except TimeoutError:
        raise
    except Exception:
        if passages_file.exists():
            with open(passages_file, "rb+") as f:
                f.truncate(rollback_size)
        with open(offset_file, "wb") as f:
            pickle.dump(offset_map_backup, f)
        raise

    prune_hnsw_embeddings_inplace(str(index_file))

    meta["total_passages"] = len(offset_map)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)

    # Reset toggles so the index on disk returns to baseline behaviour.
    try:
        index.hnsw.set_disable_rng_during_add(False)
        index.hnsw.set_disable_reverse_prune(False)
    except AttributeError:
        pass
    faiss.write_index(index, str(index_file))

    total_elapsed = time.time() - total_start

    return total_elapsed, add_elapsed


def _total_zmq_nodes(log_path: Path) -> int:
    if not log_path.exists():
        return 0
    with log_path.open("r", encoding="utf-8") as log_file:
        text = log_file.read()
    return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))


def _warmup_embedding_server(port: int) -> None:
    """Send a dummy REQ so the embedding server loads its model."""
    ctx = zmq.Context()
    try:
        sock = ctx.socket(zmq.REQ)
        sock.setsockopt(zmq.LINGER, 0)
        sock.setsockopt(zmq.RCVTIMEO, 5000)
        sock.setsockopt(zmq.SNDTIMEO, 5000)
        sock.connect(f"tcp://127.0.0.1:{port}")
        payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
        sock.send(payload)
        try:
            sock.recv()
        except zmq.error.Again:
            pass
    finally:
        sock.close()
        ctx.term()


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--index-path",
        type=Path,
        default=Path(".leann/bench/leann-demo.leann"),
        help="Output index base path (without extension).",
    )
    parser.add_argument(
        "--initial-files",
        nargs="*",
        type=Path,
        default=DEFAULT_INITIAL_FILES,
        help="Files used to build the initial index.",
    )
    parser.add_argument(
        "--update-files",
        nargs="*",
        type=Path,
        default=DEFAULT_UPDATE_FILES,
        help="Files appended during the benchmark.",
    )
    parser.add_argument(
        "--runs", type=int, default=1, help="How many times to repeat each scenario."
    )
    parser.add_argument(
        "--model-name",
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Embedding model used for build/update.",
    )
    parser.add_argument(
        "--embedding-mode",
        default="sentence-transformers",
        help="Embedding mode passed to LeannBuilder/embedding server.",
    )
    parser.add_argument(
        "--distance-metric",
        default="mips",
        choices=["mips", "l2", "cosine"],
        help="Distance metric for HNSW backend.",
    )
    parser.add_argument(
        "--ef-construction",
        type=int,
        default=200,
        help="efConstruction setting for initial build.",
    )
    parser.add_argument(
        "--server-port",
        type=int,
        default=5557,
        help="Port for the real embedding server.",
    )
    parser.add_argument(
        "--max-initial",
        type=int,
        default=300,
        help="Optional cap on initial passages (after chunking).",
    )
    parser.add_argument(
        "--max-updates",
        type=int,
        default=1,
        help="Optional cap on update passages (after chunking).",
    )
    parser.add_argument(
        "--add-timeout",
        type=int,
        default=900,
        help="Timeout in seconds for the incremental add loop (0 = no timeout).",
    )
    parser.add_argument(
        "--plot-path",
        type=Path,
        default=Path("bench_latency.png"),
        help="Where to save the latency bar plot.",
    )
    parser.add_argument(
        "--cap-y",
        type=float,
        default=None,
        help="Cap Y-axis (ms). Bars above are hatched and annotated.",
    )
    parser.add_argument(
        "--broken-y",
        action="store_true",
        help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
    )
    parser.add_argument(
        "--lower-cap-y",
        type=float,
        default=None,
        help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
    )
    parser.add_argument(
        "--upper-start-y",
        type=float,
        default=None,
        help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
    )
    parser.add_argument(
        "--csv-path",
        type=Path,
        default=Path("benchmarks/update/bench_results.csv"),
        help="Where to append per-scenario results as CSV.",
    )

    args = parser.parse_args()

    register_project_directory(REPO_ROOT)

    initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
    update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
    if not update_paragraphs:
        raise ValueError("No update passages found; please provide --update-files with content.")

    update_chunks = prepare_new_chunks(update_paragraphs)
    ensure_index_dir(args.index_path)

    # Each scenario is (name, disable_forward_rng, disable_reverse_rng, cache_enabled).
    scenarios = [
        ("baseline", False, False, True),
        ("no_cache_baseline", False, False, False),
        ("disable_forward_rng", True, False, True),
        ("disable_forward_and_reverse_rng", True, True, True),
    ]

    log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
    log_path.parent.mkdir(parents=True, exist_ok=True)
    os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
    os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")

    results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
    results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
    results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}

    # CSV setup
    import csv

    run_id = time.strftime("%Y%m%d-%H%M%S")
    csv_fields = [
        "run_id",
        "scenario",
        "cache_enabled",
        "ef_construction",
        "max_initial",
        "max_updates",
        "total_time_s",
        "add_only_s",
        "latency_ms_per_passage",
        "zmq_nodes",
        "stageA_time_s",
        "stageBC_time_s",
        "model_name",
        "embedding_mode",
        "distance_metric",
    ]
    # Create CSV with header if missing
    if args.csv_path:
        args.csv_path.parent.mkdir(parents=True, exist_ok=True)
        if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
            with args.csv_path.open("w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=csv_fields)
                writer.writeheader()

    for run in range(args.runs):
        print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
        for name, disable_forward, disable_reverse, cache_enabled in scenarios:
            print(f"\nScenario: {name}")
            cleanup_index_files(args.index_path)
            if log_path.exists():
                try:
                    log_path.unlink()
                except OSError:
                    pass
            os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
            build_initial_index(
                args.index_path,
                initial_paragraphs,
                args.model_name,
                args.embedding_mode,
                args.distance_metric,
                args.ef_construction,
            )

            prev_size = log_path.stat().st_size if log_path.exists() else 0

            try:
                total_elapsed, add_elapsed = benchmark_update_with_mode(
                    args.index_path,
                    update_chunks,
                    args.model_name,
                    args.embedding_mode,
                    args.distance_metric,
                    disable_forward,
                    disable_reverse,
                    args.server_port,
                    args.add_timeout,
                    args.ef_construction,
                )
            except TimeoutError as exc:
                print(f"Scenario {name} timed out: {exc}")
                continue

            curr_size = log_path.stat().st_size if log_path.exists() else 0
            if curr_size < prev_size:
                prev_size = 0
            zmq_count = 0
            if log_path.exists():
                with log_path.open("r", encoding="utf-8") as log_file:
                    log_file.seek(prev_size)
                    new_entries = log_file.read()
                zmq_count = sum(
                    int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
                )
                stageA = sum(
                    float(x)
                    for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
                )
                stageBC = sum(
                    float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
                )
            else:
                stageA = 0.0
                stageBC = 0.0

            per_chunk = add_elapsed / len(update_chunks)
            print(
                f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
                f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
            )
            print(f"ZMQ node fetch total: {zmq_count}")
            results_total[name].append(total_elapsed)
            results_add[name].append(add_elapsed)
            results_zmq[name].append(zmq_count)
            results_ms_per_passage[name].append(per_chunk * 1e3)
            results_stageA[name].append(stageA)
            results_stageBC[name].append(stageBC)

            # Append row to CSV
            if args.csv_path:
                row = {
                    "run_id": run_id,
                    "scenario": name,
                    "cache_enabled": 1 if cache_enabled else 0,
                    "ef_construction": args.ef_construction,
                    "max_initial": args.max_initial,
                    "max_updates": args.max_updates,
                    "total_time_s": round(total_elapsed, 6),
                    "add_only_s": round(add_elapsed, 6),
                    "latency_ms_per_passage": round(per_chunk * 1e3, 6),
                    "zmq_nodes": int(zmq_count),
                    "stageA_time_s": round(stageA, 6),
                    "stageBC_time_s": round(stageBC, 6),
                    "model_name": args.model_name,
                    "embedding_mode": args.embedding_mode,
                    "distance_metric": args.distance_metric,
                }
                with args.csv_path.open("a", newline="", encoding="utf-8") as f:
                    writer = csv.DictWriter(f, fieldnames=csv_fields)
                    writer.writerow(row)

print("\n=== Summary ===")
|
||||||
|
for name in results_add:
|
||||||
|
add_values = results_add[name]
|
||||||
|
total_values = results_total[name]
|
||||||
|
zmq_values = results_zmq[name]
|
||||||
|
latency_values = results_ms_per_passage[name]
|
||||||
|
if not add_values:
|
||||||
|
print(f"{name}: no successful runs")
|
||||||
|
continue
|
||||||
|
avg_add = sum(add_values) / len(add_values)
|
||||||
|
avg_total = sum(total_values) / len(total_values)
|
||||||
|
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||||
|
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||||
|
runs = len(add_values)
|
||||||
|
print(
|
||||||
|
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||||
|
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
    if args.plot_path:
        try:
            import matplotlib.pyplot as plt

            labels = [name for name, *_ in scenarios]
            values = [
                sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
                if results_ms_per_passage[name]
                else 0.0
                for name in labels
            ]

            def _auto_cap(vals: list[float]) -> float | None:
                s = sorted(vals, reverse=True)
                if len(s) < 2:
                    return None
                if s[1] > 0 and s[0] >= 2.5 * s[1]:
                    return s[1] * 1.1
                return None

            def _fmt_ms(v: float) -> str:
                return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"

            colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]

            if args.broken_y:
                s = sorted(values, reverse=True)
                second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
                lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
                upper_start = (
                    args.upper_start_y
                    if args.upper_start_y is not None
                    else max(second * 1.2, lower_cap * 1.02)
                )
                ymax = max(values) * 1.10 if values else 1.0
                fig, (ax_top, ax_bottom) = plt.subplots(
                    2,
                    1,
                    sharex=True,
                    figsize=(7.4, 5.0),
                    gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
                )
                x = list(range(len(labels)))
                ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
                ax_bottom.set_ylim(0, lower_cap)
                ax_top.set_ylim(upper_start, ymax)
                for i, v in enumerate(values):
                    if v <= lower_cap:
                        ax_bottom.text(
                            i,
                            v + lower_cap * 0.02,
                            _fmt_ms(v),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                    else:
                        ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_top.spines["bottom"].set_visible(False)
                ax_bottom.spines["top"].set_visible(False)
                ax_top.tick_params(labeltop=False)
                ax_bottom.xaxis.tick_bottom()
                d = 0.015
                kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
                ax_top.plot((-d, +d), (-d, +d), **kwargs)
                ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
                kwargs.update({"transform": ax_bottom.transAxes})
                ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
                ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
                ax_bottom.set_xticks(range(len(labels)))
                ax_bottom.set_xticklabels(labels)
                ax = ax_bottom
            else:
                cap = args.cap_y or _auto_cap(values)
                plt.figure(figsize=(7.2, 4.2))
                ax = plt.gca()
                if cap is not None:
                    show_vals = [min(v, cap) for v in values]
                    bars = []
                    for i, (v, show) in enumerate(zip(values, show_vals)):
                        b = ax.bar(i, show, color=colors[i], width=0.8)
                        bars.append(b[0])
                        if v > cap:
                            bars[-1].set_hatch("//")
                            ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                        else:
                            ax.text(
                                i,
                                show + max(1.0, 0.01 * (cap or show)),
                                _fmt_ms(v),
                                ha="center",
                                va="bottom",
                                fontsize=9,
                            )
                    ax.set_ylim(0, cap * 1.10)
                    ax.plot(
                        [0.02 - 0.02, 0.02 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    ax.plot(
                        [0.98 - 0.02, 0.98 + 0.02],
                        [0.98 + 0.02, 0.98 - 0.02],
                        transform=ax.transAxes,
                        color="k",
                        lw=1,
                    )
                    if any(v > cap for v in values):
                        ax.legend(
                            [bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
                        )
                    ax.set_xticks(range(len(labels)))
                    ax.set_xticklabels(labels)
                else:
                    ax.bar(labels, values, color=colors[: len(labels)])
                    for idx, val in enumerate(values):
                        ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")

            plt.ylabel("Average add latency (ms per passage)")
            plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
            plt.tight_layout()
            plt.savefig(args.plot_path)
            print(f"Saved latency bar plot to {args.plot_path}")
            # ZMQ time split (Stage A vs B/C)
            try:
                plt.figure(figsize=(6, 4))
                a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
                bc_vals = [
                    sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
                ]
                ind = range(len(labels))
                plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
                plt.bar(
                    ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
                )
                plt.xticks(list(ind), labels, rotation=10)
                plt.ylabel("Server ZMQ time (s)")
                plt.title(
                    f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
                )
                plt.legend()
                out2 = args.plot_path.with_name(
                    args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
                )
                plt.tight_layout()
                plt.savefig(out2)
                print(f"Saved ZMQ time split plot to {out2}")
            except Exception as e:
                print("Failed to plot ZMQ split:", e)
        except ImportError:
            print("matplotlib not available; skipping plot generation")

    # leave the last build on disk for inspection


if __name__ == "__main__":
    main()

benchmarks/update/bench_results.csv (new file, 5 lines)
@@ -0,0 +1,5 @@
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
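For reference, a minimal sketch of how these rows can be summarized per scenario, assuming the file stays at the default --csv-path location used by the benchmark and keeps the header above:

import csv
from collections import defaultdict

latencies = defaultdict(list)
with open("benchmarks/update/bench_results.csv", newline="", encoding="utf-8") as f:
    for row in csv.DictReader(f):
        latencies[row["scenario"]].append(float(row["latency_ms_per_passage"]))

for scenario, values in latencies.items():
    print(f"{scenario}: {sum(values) / len(values):.1f} ms/passage over {len(values)} run(s)")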
benchmarks/update/bench_update_vs_offline_search.py (new file, 704 lines)
@@ -0,0 +1,704 @@
"""
|
||||||
|
Compare two latency models for small incremental updates vs. search:
|
||||||
|
|
||||||
|
Scenario A (sequential update then search):
|
||||||
|
- Build initial HNSW (is_recompute=True)
|
||||||
|
- Start embedding server (ZMQ) for recompute
|
||||||
|
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||||
|
- Then run a search query on the updated index
|
||||||
|
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||||
|
|
||||||
|
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||||
|
- Do NOT insert the N passages into the graph
|
||||||
|
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||||
|
embedding and run a search on the existing index
|
||||||
|
- After both finish, compute similarity between the query embedding and the N
|
||||||
|
new passage embeddings, merge with the index search results by score, and
|
||||||
|
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||||
|
|
||||||
|
This script reuses the model/data loading conventions of
|
||||||
|
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||||
|
comparison for the two execution strategies above.
|
||||||
|
|
||||||
|
Example (from the repository root):
|
||||||
|
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||||
|
--index-path .leann/bench/offline_vs_update.leann \
|
||||||
|
--max-initial 300 --num-updates 5 --k 10
|
||||||
|
"""
|
||||||
|
|
||||||
|
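# Illustrative sketch of the "merge by score" step described above (made-up
# numbers, not produced by this script): index hits and offline-embedded
# passages are put on a common "higher is better" scale (dot product for
# mips/cosine, negated distance for l2) and the combined list is cut to top-k.
#
#   index_hits     = [("idx:12", 0.71), ("idx:3", 0.64)]
#   offline_scores = [("offline:0", 0.69), ("offline:1", 0.40)]
#   merged = sorted(index_hits + offline_scores, key=lambda x: x[1], reverse=True)[:3]
#   # -> [("idx:12", 0.71), ("offline:0", 0.69), ("idx:3", 0.64)]
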
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import psutil # type: ignore
|
||||||
|
from leann.api import LeannBuilder
|
||||||
|
|
||||||
|
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||||
|
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||||
|
|
||||||
|
from leann.embedding_compute import compute_embeddings
|
||||||
|
from leann.embedding_server_manager import EmbeddingServerManager
|
||||||
|
from leann.registry import register_project_directory
|
||||||
|
from leann_backend_hnsw import faiss # type: ignore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
if not logging.getLogger().handlers:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_repo_root() -> Path:
|
||||||
|
"""Locate project root by walking up until pyproject.toml is found."""
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for parent in current.parents:
|
||||||
|
if (parent / "pyproject.toml").exists():
|
||||||
|
return parent
|
||||||
|
# Fallback: assume repo is two levels up (../..)
|
||||||
|
return current.parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
REPO_ROOT = _find_repo_root()
|
||||||
|
if str(REPO_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(REPO_ROOT))
|
||||||
|
|
||||||
|
from apps.chunking import create_text_chunks # noqa: E402
|
||||||
|
|
||||||
|
DEFAULT_INITIAL_FILES = [
|
||||||
|
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||||
|
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||||
|
]
|
||||||
|
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||||
|
|
||||||
|
|
||||||
|
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||||
|
from llama_index.core import SimpleDirectoryReader
|
||||||
|
|
||||||
|
documents = []
|
||||||
|
for path in paths:
|
||||||
|
p = path.expanduser().resolve()
|
||||||
|
if not p.exists():
|
||||||
|
raise FileNotFoundError(f"Input path not found: {p}")
|
||||||
|
if p.is_dir():
|
||||||
|
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
else:
|
||||||
|
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||||
|
documents.extend(reader.load_data(show_progress=True))
|
||||||
|
|
||||||
|
if not documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
chunks = create_text_chunks(
|
||||||
|
documents,
|
||||||
|
chunk_size=512,
|
||||||
|
chunk_overlap=128,
|
||||||
|
use_ast_chunking=False,
|
||||||
|
)
|
||||||
|
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||||
|
if limit is not None:
|
||||||
|
cleaned = cleaned[:limit]
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_index_dir(index_path: Path) -> None:
|
||||||
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_index_files(index_path: Path) -> None:
|
||||||
|
parent = index_path.parent
|
||||||
|
if not parent.exists():
|
||||||
|
return
|
||||||
|
stem = index_path.stem
|
||||||
|
for file in parent.glob(f"{stem}*"):
|
||||||
|
if file.is_file():
|
||||||
|
file.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
def build_initial_index(
|
||||||
|
index_path: Path,
|
||||||
|
paragraphs: list[str],
|
||||||
|
model_name: str,
|
||||||
|
embedding_mode: str,
|
||||||
|
distance_metric: str,
|
||||||
|
ef_construction: int,
|
||||||
|
) -> None:
|
||||||
|
builder = LeannBuilder(
|
||||||
|
backend_name="hnsw",
|
||||||
|
embedding_model=model_name,
|
||||||
|
embedding_mode=embedding_mode,
|
||||||
|
is_compact=False,
|
||||||
|
is_recompute=True,
|
||||||
|
distance_metric=distance_metric,
|
||||||
|
backend_kwargs={
|
||||||
|
"distance_metric": distance_metric,
|
||||||
|
"is_compact": False,
|
||||||
|
"is_recompute": True,
|
||||||
|
"efConstruction": ef_construction,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for idx, passage in enumerate(paragraphs):
|
||||||
|
builder.add_text(passage, metadata={"id": str(idx)})
|
||||||
|
builder.build_index(str(index_path))
|
||||||
|
|
||||||
|
|
||||||
|
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||||
|
if metric == "cosine":
|
||||||
|
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||||
|
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||||
|
norms[norms == 0] = 1
|
||||||
|
vecs = vecs / norms
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
|
||||||
|
def _read_index_for_search(index_path: Path) -> Any:
|
||||||
|
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||||
|
# Force-disable experimental disk cache when loading the index so that
|
||||||
|
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||||
|
cfg = faiss.HNSWIndexConfig()
|
||||||
|
cfg.is_recompute = True
|
||||||
|
if hasattr(cfg, "disk_cache_ratio"):
|
||||||
|
cfg.disk_cache_ratio = 0.0
|
||||||
|
if hasattr(cfg, "external_storage_path"):
|
||||||
|
cfg.external_storage_path = None
|
||||||
|
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||||
|
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||||
|
# ensure recompute mode persists after reload
|
||||||
|
try:
|
||||||
|
index.is_recompute = True
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
actual_ntotal = index.hnsw.levels.size()
|
||||||
|
except AttributeError:
|
||||||
|
actual_ntotal = index.ntotal
|
||||||
|
if actual_ntotal != index.ntotal:
|
||||||
|
print(
|
||||||
|
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
index.ntotal = actual_ntotal
|
||||||
|
if getattr(index, "storage", None) is None:
|
||||||
|
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||||
|
storage_index = faiss.IndexFlatIP(index.d)
|
||||||
|
else:
|
||||||
|
storage_index = faiss.IndexFlatL2(index.d)
|
||||||
|
index.storage = storage_index
|
||||||
|
index.own_fields = True
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
|
def _append_passages_for_updates(
|
||||||
|
meta_path: Path,
|
||||||
|
start_id: int,
|
||||||
|
texts: list[str],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
index_dir = meta_path.parent
|
||||||
|
meta_name = meta_path.name
|
||||||
|
if not meta_name.endswith(".meta.json"):
|
||||||
|
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||||
|
index_base = meta_name[: -len(".meta.json")]
|
||||||
|
|
||||||
|
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||||
|
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||||
|
|
||||||
|
if not passages_file.exists() or not offsets_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Passage store missing; cannot register update passages for recompute mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(offsets_file, "rb") as f:
|
||||||
|
offset_map: dict[str, int] = pickle.load(f)
|
||||||
|
|
||||||
|
assigned_ids: list[str] = []
|
||||||
|
with open(passages_file, "a", encoding="utf-8") as f:
|
||||||
|
for i, text in enumerate(texts):
|
||||||
|
passage_id = str(start_id + i)
|
||||||
|
offset = f.tell()
|
||||||
|
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||||
|
f.write("\n")
|
||||||
|
offset_map[passage_id] = offset
|
||||||
|
assigned_ids.append(passage_id)
|
||||||
|
|
||||||
|
with open(offsets_file, "wb") as f:
|
||||||
|
pickle.dump(offset_map, f)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(meta_path, encoding="utf-8") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
meta = {}
|
||||||
|
meta["total_passages"] = len(offset_map)
|
||||||
|
with open(meta_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
|
return assigned_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||||
|
distances = np.zeros((1, k), dtype=np.float32)
|
||||||
|
indices = np.zeros((1, k), dtype=np.int64)
|
||||||
|
index.search(
|
||||||
|
1,
|
||||||
|
faiss.swig_ptr(q),
|
||||||
|
k,
|
||||||
|
faiss.swig_ptr(distances),
|
||||||
|
faiss.swig_ptr(indices),
|
||||||
|
)
|
||||||
|
return distances[0], indices[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _score_for_metric(dist: float, metric: str) -> float:
|
||||||
|
# Convert FAISS distance to a "higher is better" score
|
||||||
|
if metric in ("mips", "cosine"):
|
||||||
|
return float(dist)
|
||||||
|
# l2 distance (smaller better) -> negative distance as score
|
||||||
|
return -float(dist)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray],
|
||||||
|
offline_scores: list[tuple[int, float]],
|
||||||
|
k: int,
|
||||||
|
metric: str,
|
||||||
|
) -> list[tuple[str, float]]:
|
||||||
|
distances, indices = index_results
|
||||||
|
merged: list[tuple[str, float]] = []
|
||||||
|
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||||
|
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||||
|
for j, s in offline_scores:
|
||||||
|
merged.append((f"offline:{j}", s))
|
||||||
|
merged.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
return merged[:k]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScenarioResult:
|
||||||
|
name: str
|
||||||
|
update_total_s: float
|
||||||
|
search_s: float
|
||||||
|
overall_s: float
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--index-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--initial-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_INITIAL_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--update-files",
|
||||||
|
nargs="*",
|
||||||
|
type=Path,
|
||||||
|
default=DEFAULT_UPDATE_FILES,
|
||||||
|
)
|
||||||
|
parser.add_argument("--max-initial", type=int, default=300)
|
||||||
|
parser.add_argument("--num-updates", type=int, default=5)
|
||||||
|
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||||
|
parser.add_argument(
|
||||||
|
"--query",
|
||||||
|
type=str,
|
||||||
|
default="neural network",
|
||||||
|
help="Query text used for the search benchmark.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--server-port", type=int, default=5557)
|
||||||
|
parser.add_argument("--add-timeout", type=int, default=600)
|
||||||
|
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||||
|
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||||
|
parser.add_argument(
|
||||||
|
"--distance-metric",
|
||||||
|
default="mips",
|
||||||
|
choices=["mips", "l2", "cosine"],
|
||||||
|
)
|
||||||
|
parser.add_argument("--ef-construction", type=int, default=200)
|
||||||
|
parser.add_argument(
|
||||||
|
"--only",
|
||||||
|
choices=["A", "B", "both"],
|
||||||
|
default="both",
|
||||||
|
help="Run only Scenario A, Scenario B, or both",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--csv-path",
|
||||||
|
type=Path,
|
||||||
|
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||||
|
help="Where to append results (CSV).",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
register_project_directory(REPO_ROOT)
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||||
|
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||||
|
if not update_paragraphs:
|
||||||
|
raise ValueError("No update passages loaded from --update-files")
|
||||||
|
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||||
|
if len(update_paragraphs) < args.num_updates:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ensure_index_dir(args.index_path)
|
||||||
|
cleanup_index_files(args.index_path)
|
||||||
|
|
||||||
|
# Build initial index
|
||||||
|
build_initial_index(
|
||||||
|
args.index_path,
|
||||||
|
initial_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
args.embedding_mode,
|
||||||
|
args.distance_metric,
|
||||||
|
args.ef_construction,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare index object and meta
|
||||||
|
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||||
|
index = _read_index_for_search(args.index_path)
|
||||||
|
|
||||||
|
# CSV setup
|
||||||
|
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
if args.csv_path:
|
||||||
|
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
csv_fields = [
|
||||||
|
"run_id",
|
||||||
|
"scenario",
|
||||||
|
"max_initial",
|
||||||
|
"num_updates",
|
||||||
|
"k",
|
||||||
|
"total_time_s",
|
||||||
|
"add_total_s",
|
||||||
|
"search_time_s",
|
||||||
|
"emb_time_s",
|
||||||
|
"makespan_s",
|
||||||
|
"model_name",
|
||||||
|
"embedding_mode",
|
||||||
|
"distance_metric",
|
||||||
|
]
|
||||||
|
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||||
|
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
# Debug: list existing HNSW server PIDs before starting
|
||||||
|
try:
|
||||||
|
existing = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if existing:
|
||||||
|
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||||
|
for p in existing:
|
||||||
|
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||||
|
except Exception as _e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
add_total = 0.0
|
||||||
|
search_after_add = 0.0
|
||||||
|
total_seq = 0.0
|
||||||
|
port_a = None
|
||||||
|
if args.only in ("A", "both"):
|
||||||
|
# Scenario A: sequential update then search
|
||||||
|
start_id = index.ntotal
|
||||||
|
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||||
|
if assigned_ids:
|
||||||
|
logger.debug(
|
||||||
|
"Registered %d update passages starting at id %s",
|
||||||
|
len(assigned_ids),
|
||||||
|
assigned_ids[0],
|
||||||
|
)
|
||||||
|
server_manager = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
ok, port = server_manager.start_server(
|
||||||
|
port=args.server_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("Failed to start embedding server")
|
||||||
|
try:
|
||||||
|
# Set ZMQ port for recompute mode
|
||||||
|
if hasattr(index.hnsw, "set_zmq_port"):
|
||||||
|
index.hnsw.set_zmq_port(port)
|
||||||
|
elif hasattr(index, "set_zmq_port"):
|
||||||
|
index.set_zmq_port(port)
|
||||||
|
|
||||||
|
# Start A overall timer BEFORE computing update embeddings
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
# Compute embeddings for updates (counted into A's overall)
|
||||||
|
t_emb0 = time.time()
|
||||||
|
upd_embs = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time_updates = time.time() - t_emb0
|
||||||
|
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||||
|
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||||
|
|
||||||
|
# Perform sequential adds
|
||||||
|
for i in range(upd_embs.shape[0]):
|
||||||
|
t_add0 = time.time()
|
||||||
|
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||||
|
add_total += time.time() - t_add0
|
||||||
|
# Don't persist index after adds to avoid contaminating Scenario B
|
||||||
|
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||||
|
# faiss.write_index(index, str(index_file))
|
||||||
|
|
||||||
|
# Search after updates
|
||||||
|
q_emb = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||||
|
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||||
|
|
||||||
|
# Warm up search with a dummy query first
|
||||||
|
print("[DEBUG] Warming up search...")
|
||||||
|
_ = _search(index, q_emb, 1)
|
||||||
|
|
||||||
|
t_s0 = time.time()
|
||||||
|
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||||
|
search_after_add = time.time() - t_s0
|
||||||
|
total_seq = time.time() - t0
|
||||||
|
finally:
|
||||||
|
server_manager.stop_server()
|
||||||
|
port_a = port
|
||||||
|
|
||||||
|
print("\n=== Scenario A: update->search (sequential) ===")
|
||||||
|
# emb_time_updates is defined only when A runs
|
||||||
|
try:
|
||||||
|
_emb_a = emb_time_updates
|
||||||
|
except NameError:
|
||||||
|
_emb_a = 0.0
|
||||||
|
print(
|
||||||
|
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||||
|
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||||
|
)
|
||||||
|
# CSV row for A
|
||||||
|
if args.csv_path:
|
||||||
|
row_a = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "A",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": round(total_seq, 6),
|
||||||
|
"add_total_s": round(add_total, 6),
|
||||||
|
"search_time_s": round(search_after_add, 6),
|
||||||
|
"emb_time_s": round(_emb_a, 6),
|
||||||
|
"makespan_s": 0.0,
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_a)
|
||||||
|
|
||||||
|
# Verify server cleanup
|
||||||
|
try:
|
||||||
|
# short sleep to allow signal handling to finish
|
||||||
|
time.sleep(0.5)
|
||||||
|
leftovers = [
|
||||||
|
p
|
||||||
|
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||||
|
if any(
|
||||||
|
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||||
|
for arg in (p.info.get("cmdline") or [])
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if leftovers:
|
||||||
|
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||||
|
for p in leftovers:
|
||||||
|
print(
|
||||||
|
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||||
|
if args.only in ("B", "both"):
|
||||||
|
# ensure a server is available for recompute search
|
||||||
|
server_manager_b = EmbeddingServerManager(
|
||||||
|
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||||
|
)
|
||||||
|
requested_port = args.server_port if port_a is None else port_a
|
||||||
|
ok_b, port_b = server_manager_b.start_server(
|
||||||
|
port=requested_port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
embedding_mode=args.embedding_mode,
|
||||||
|
passages_file=str(meta_path),
|
||||||
|
distance_metric=args.distance_metric,
|
||||||
|
)
|
||||||
|
if not ok_b:
|
||||||
|
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||||
|
|
||||||
|
# Wait for server to fully initialize
|
||||||
|
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read the index first
|
||||||
|
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||||
|
|
||||||
|
# Then configure ZMQ port on the correct index object
|
||||||
|
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||||
|
index_no_update.hnsw.set_zmq_port(port_b)
|
||||||
|
elif hasattr(index_no_update, "set_zmq_port"):
|
||||||
|
index_no_update.set_zmq_port(port_b)
|
||||||
|
|
||||||
|
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||||
|
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||||
|
logger.info("Warming up embedding model for Scenario B...")
|
||||||
|
_ = compute_embeddings(
|
||||||
|
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare worker A: compute embeddings for the same N passages
|
||||||
|
emb_time = 0.0
|
||||||
|
updates_embs_offline: np.ndarray | None = None
|
||||||
|
|
||||||
|
def _worker_emb():
|
||||||
|
nonlocal emb_time, updates_embs_offline
|
||||||
|
t = time.time()
|
||||||
|
updates_embs_offline = compute_embeddings(
|
||||||
|
update_paragraphs,
|
||||||
|
args.model_name,
|
||||||
|
mode=args.embedding_mode,
|
||||||
|
is_build=False,
|
||||||
|
batch_size=16,
|
||||||
|
)
|
||||||
|
emb_time = time.time() - t
|
||||||
|
|
||||||
|
# Pre-compute query embedding and warm up search outside of timed section.
|
||||||
|
q_vec = compute_embeddings(
|
||||||
|
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||||
|
)
|
||||||
|
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||||
|
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||||
|
print("[DEBUG B] Warming up search...")
|
||||||
|
_ = _search(index_no_update, q_vec, 1)
|
||||||
|
|
||||||
|
# Worker B: timed search on the warmed index
|
||||||
|
search_time = 0.0
|
||||||
|
offline_elapsed = 0.0
|
||||||
|
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||||
|
|
||||||
|
def _worker_search():
|
||||||
|
nonlocal search_time, index_results
|
||||||
|
t = time.time()
|
||||||
|
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||||
|
search_time = time.time() - t
|
||||||
|
index_results = (distances, indices)
|
||||||
|
|
||||||
|
# Run two workers concurrently
|
||||||
|
t0 = time.time()
|
||||||
|
th1 = threading.Thread(target=_worker_emb)
|
||||||
|
th2 = threading.Thread(target=_worker_search)
|
||||||
|
th1.start()
|
||||||
|
th2.start()
|
||||||
|
th1.join()
|
||||||
|
th2.join()
|
||||||
|
offline_elapsed = time.time() - t0
|
||||||
|
|
||||||
|
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||||
|
offline_scores: list[tuple[int, float]] = []
|
||||||
|
if updates_embs_offline is not None:
|
||||||
|
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||||
|
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||||
|
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||||
|
for j in range(upd2.shape[0]):
|
||||||
|
if args.distance_metric in ("mips", "cosine"):
|
||||||
|
s = float(np.dot(q_vec[0], upd2[j]))
|
||||||
|
else:
|
||||||
|
diff = q_vec[0] - upd2[j]
|
||||||
|
s = -float(np.dot(diff, diff))
|
||||||
|
offline_scores.append((j, s))
|
||||||
|
|
||||||
|
merged_topk = (
|
||||||
|
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||||
|
if index_results
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||||
|
print(
|
||||||
|
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||||
|
)
|
||||||
|
if merged_topk:
|
||||||
|
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||||
|
print(f"Merged top-5 preview: {preview}")
|
||||||
|
# CSV row for B
|
||||||
|
if args.csv_path:
|
||||||
|
row_b = {
|
||||||
|
"run_id": run_id,
|
||||||
|
"scenario": "B",
|
||||||
|
"max_initial": args.max_initial,
|
||||||
|
"num_updates": args.num_updates,
|
||||||
|
"k": args.k,
|
||||||
|
"total_time_s": 0.0,
|
||||||
|
"add_total_s": 0.0,
|
||||||
|
"search_time_s": round(search_time, 6),
|
||||||
|
"emb_time_s": round(emb_time, 6),
|
||||||
|
"makespan_s": round(offline_elapsed, 6),
|
||||||
|
"model_name": args.model_name,
|
||||||
|
"embedding_mode": args.embedding_mode,
|
||||||
|
"distance_metric": args.distance_metric,
|
||||||
|
}
|
||||||
|
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||||
|
writer.writerow(row_b)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
server_manager_b.stop_server()
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print("\n=== Summary ===")
|
||||||
|
msg_a = (
|
||||||
|
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||||
|
if args.only in ("A", "both")
|
||||||
|
else "A: skipped"
|
||||||
|
)
|
||||||
|
msg_b = (
|
||||||
|
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||||
|
if args.only in ("B", "both")
|
||||||
|
else "B: skipped"
|
||||||
|
)
|
||||||
|
print(msg_a + "\n" + msg_b)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
benchmarks/update/offline_vs_update.csv (new file, 5 lines)
@@ -0,0 +1,5 @@
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
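A quick way to compare the two scenarios for the latest run is to read total_time_s for Scenario A against makespan_s for Scenario B; a minimal sketch assuming the default CSV path above:

import csv

with open("benchmarks/update/offline_vs_update.csv", newline="", encoding="utf-8") as f:
    rows = list(csv.DictReader(f))
latest = max(r["run_id"] for r in rows)
a = next(r for r in rows if r["run_id"] == latest and r["scenario"] == "A")
b = next(r for r in rows if r["run_id"] == latest and r["scenario"] == "B")
print(f"A (update then search): {float(a['total_time_s']):.3f}s vs "
      f"B (offline + concurrent): {float(b['makespan_s']):.3f}s")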
benchmarks/update/plot_bench_results.py (new file, 645 lines)
@@ -0,0 +1,645 @@
#!/usr/bin/env python3
"""
Plot latency bars from the benchmark CSV produced by
benchmarks/update/bench_hnsw_rng_recompute.py.

If you also provide an offline_vs_update.csv via --csv-right
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
output a side-by-side figure:
- Left: ms/passage bars (four RNG scenarios).
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).

Usage:
    uv run python benchmarks/update/plot_bench_results.py \
        --csv benchmarks/update/bench_results.csv \
        --out benchmarks/update/bench_latency_from_csv.png

The script selects the latest run_id in the CSV and plots four bars for
the default scenarios:
- baseline
- no_cache_baseline
- disable_forward_rng
- disable_forward_and_reverse_rng

If multiple rows exist per scenario for that run_id, the script averages
their latency_ms_per_passage values.
"""

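# Note (illustration only): run_id values use the zero-padded YYYYMMDD-HHMMSS
# format written by the benchmark scripts, so the lexicographically largest
# run_id is also the most recent one, e.g.
#   max(["20251024-133101", "20251025-160652"]) -> "20251025-160652"
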
import argparse
|
||||||
|
import csv
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DEFAULT_SCENARIOS = [
|
||||||
|
"no_cache_baseline",
|
||||||
|
"baseline",
|
||||||
|
"disable_forward_rng",
|
||||||
|
"disable_forward_and_reverse_rng",
|
||||||
|
]
|
||||||
|
|
||||||
|
SCENARIO_LABELS = {
|
||||||
|
"baseline": "+ Cache",
|
||||||
|
"no_cache_baseline": "Naive \n Recompute",
|
||||||
|
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||||
|
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Paper-style colors and hatches for scenarios
|
||||||
|
SCENARIO_STYLES = {
|
||||||
|
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||||
|
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||||
|
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||||
|
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_latest_run(csv_path: Path):
|
||||||
|
rows = []
|
||||||
|
with csv_path.open("r", encoding="utf-8") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
rows.append(row)
|
||||||
|
if not rows:
|
||||||
|
raise SystemExit("CSV is empty: no rows to plot")
|
||||||
|
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||||
|
run_ids = [r.get("run_id", "") for r in rows]
|
||||||
|
latest = max(run_ids)
|
||||||
|
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||||
|
if not latest_rows:
|
||||||
|
# Fallback: take last 4 rows
|
||||||
|
latest_rows = rows[-4:]
|
||||||
|
latest = latest_rows[-1].get("run_id", "unknown")
|
||||||
|
return latest, latest_rows
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_latency(rows):
|
||||||
|
acc = defaultdict(list)
|
||||||
|
for r in rows:
|
||||||
|
sc = r.get("scenario", "")
|
||||||
|
try:
|
||||||
|
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
acc[sc].append(val)
|
||||||
|
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def _auto_cap(values: list[float]) -> float | None:
|
||||||
|
if not values:
|
||||||
|
return None
|
||||||
|
sorted_vals = sorted(values, reverse=True)
|
||||||
|
if len(sorted_vals) < 2:
|
||||||
|
return None
|
||||||
|
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||||
|
if second <= 0:
|
||||||
|
return None
|
||||||
|
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||||
|
if max_v >= 2.5 * second:
|
||||||
|
return second * 1.1
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||||
|
# Draw small diagonal ticks near left/right to signal cap
|
||||||
|
x0, x1 = rel_x0, rel_x1
|
||||||
|
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_ms(v: float) -> str:
|
||||||
|
if v >= 1000:
|
||||||
|
return f"{v / 1000:.1f}k"
|
||||||
|
return f"{v:.1f}"


def main():
    # Set LaTeX style for paper figures (matching paper_fig.py)
    import matplotlib.pyplot as plt

    plt.rcParams["font.family"] = "Helvetica"
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["hatch.linewidth"] = 1.5
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["text.usetex"] = True

    ap = argparse.ArgumentParser(description=__doc__)
    ap.add_argument(
        "--csv",
        type=Path,
        default=Path("benchmarks/update/bench_results.csv"),
        help="Path to results CSV (defaults to bench_results.csv)",
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path("add_ablation.pdf"),
        help="Output image path",
    )
    ap.add_argument(
        "--csv-right",
        type=Path,
        default=Path("benchmarks/update/offline_vs_update.csv"),
        help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
    )
    ap.add_argument(
        "--cap-y",
        type=float,
        default=None,
        help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
    )
    ap.add_argument(
        "--no-auto-cap",
        action="store_true",
        help="Disable auto-cap heuristic when --cap-y is not provided.",
    )
    ap.add_argument(
        "--broken-y",
        action="store_true",
        default=True,
        help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
    )
    ap.add_argument(
        "--lower-cap-y",
        type=float,
        default=None,
        help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
    )
    ap.add_argument(
        "--upper-start-y",
        type=float,
        default=None,
        help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
    )
    args = ap.parse_args()
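
    # Example invocation (added note; this hunk does not show the script's own
    # filename, so the path below is illustrative only):
    #   python benchmarks/update/plot_add_ablation.py \
    #       --csv benchmarks/update/bench_results.csv \
    #       --csv-right benchmarks/update/offline_vs_update.csv \
    #       --out add_ablation.pdf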

    latest_run, latest_rows = load_latest_run(args.csv)
    avg = aggregate_latency(latest_rows)

    try:
        import matplotlib.pyplot as plt
    except Exception as e:
        raise SystemExit(f"matplotlib not available: {e}")

    scenarios = DEFAULT_SCENARIOS
    values = [avg.get(name, 0.0) for name in scenarios]
    labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
    colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]

    # If right CSV is provided, build side-by-side figure
    if args.csv_right is not None:
        try:
            right_rows_all = []
            with args.csv_right.open("r", encoding="utf-8") as f:
                rreader = csv.DictReader(f)
                right_rows_all = list(rreader)
            if right_rows_all:
                r_latest = max(r.get("run_id", "") for r in right_rows_all)
                right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
            else:
                r_latest = None
                right_rows = []
        except Exception:
            r_latest = None
            right_rows = []

        a_total = 0.0
        b_makespan = 0.0
        for r in right_rows:
            sc = (r.get("scenario", "") or "").strip().upper()
            if sc == "A":
                try:
                    a_total = float(r.get("total_time_s", 0.0))
                except Exception:
                    pass
            elif sc == "B":
                try:
                    b_makespan = float(r.get("makespan_s", 0.0))
                except Exception:
                    pass

        import matplotlib.pyplot as plt
        from matplotlib import gridspec

        # Left subplot (reuse current style, with optional cap)
        cap = args.cap_y
        if cap is None and not args.no_auto_cap:
            cap = _auto_cap(values)
        x = list(range(len(labels)))

        if args.broken_y:
            # Use broken axis for left subplot
            # Auto-adjust width ratios: left has 4 bars, right has 2 bars
            fig = plt.figure(figsize=(4.8, 1.8))  # Scaled down to 80%
            gs = gridspec.GridSpec(
                2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
            )
            ax_left_top = fig.add_subplot(gs[0, 0])
            ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
            ax_right = fig.add_subplot(gs[:, 1])

            # Determine break points
            s = sorted(values, reverse=True)
            second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
            lower_cap = (
                args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
            )  # Increased to show more range
            upper_start = (
                args.upper_start_y
                if args.upper_start_y is not None
                else max(second * 1.5, lower_cap * 1.02)
            )
            ymax = (
                max(values) * 1.90 if values else 1.0
            )  # Increase headroom to 1.90 for text label and tick range

            # Draw bars on both axes
            ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
            ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)

            # Set limits
            ax_left_bottom.set_ylim(0, lower_cap)
            ax_left_top.set_ylim(upper_start, ymax)

            # Annotate values (convert ms to s)
            values_s = [v / 1000.0 for v in values]
            lower_cap_s = lower_cap / 1000.0
            upper_start_s = upper_start / 1000.0
            ymax_s = ymax / 1000.0

            ax_left_bottom.set_ylim(0, lower_cap_s)
            ax_left_top.set_ylim(upper_start_s, ymax_s)

            # Redraw bars with s values (paper style: white fill + colored edge + hatch)
            ax_left_bottom.clear()
            ax_left_top.clear()
            bar_width = 0.50  # Reduced for wider spacing between bars
            for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
                style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
                # Draw in bottom axis for all bars
                ax_left_bottom.bar(
                    i,
                    v,
                    width=bar_width,
                    color="white",
                    edgecolor=style["edgecolor"],
                    hatch=style["hatch"],
                    linewidth=1.2,
                )
                # Only draw in top axis if the bar is tall enough to reach the upper range
                if v > upper_start_s:
                    ax_left_top.bar(
                        i,
                        v,
                        width=bar_width,
                        color="white",
                        edgecolor=style["edgecolor"],
                        hatch=style["hatch"],
                        linewidth=1.2,
                    )
            ax_left_bottom.set_ylim(0, lower_cap_s)
            ax_left_top.set_ylim(upper_start_s, ymax_s)

            for i, v in enumerate(values_s):
                if v <= lower_cap_s:
                    ax_left_bottom.text(
                        i,
                        v + lower_cap_s * 0.02,
                        f"{v:.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        fontweight="bold",
                    )
                else:
                    ax_left_top.text(
                        i,
                        v + (ymax_s - upper_start_s) * 0.02,
                        f"{v:.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        fontweight="bold",
                    )

            # Hide spines between axes
            ax_left_top.spines["bottom"].set_visible(False)
            ax_left_bottom.spines["top"].set_visible(False)
            ax_left_top.tick_params(
                labeltop=False, labelbottom=False, bottom=False
            )  # Hide tick marks
            ax_left_bottom.xaxis.tick_bottom()
            ax_left_bottom.tick_params(top=False)  # Hide top tick marks

            # Draw break marks (matching paper_fig.py style)
            d = 0.015
            kwargs = {
                "transform": ax_left_top.transAxes,
                "color": "k",
                "clip_on": False,
                "linewidth": 0.8,
                "zorder": 10,
            }
            ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
            ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
            kwargs.update({"transform": ax_left_bottom.transAxes})
            ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
            ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)

            ax_left_bottom.set_xticks(x)
            ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
            # Don't set ylabel here - will use fig.text for alignment
            ax_left_bottom.tick_params(axis="y", labelsize=10)
            ax_left_top.tick_params(axis="y", labelsize=10)
            # Add subtle grid for better readability
            ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
            ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
            ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")

            # Set x-axis limits to match bar width with right subplot
            ax_left_bottom.set_xlim(-0.6, 3.6)
            ax_left_top.set_xlim(-0.6, 3.6)

            ax_left = ax_left_bottom  # for compatibility
        else:
            # Regular side-by-side layout
            fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))

            if cap is not None:
                show_vals = [min(v, cap) for v in values]
                bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
                for i, (val, show) in enumerate(zip(values, show_vals)):
                    if val > cap:
                        bars[i].set_hatch("//")
                        ax_left.text(
                            i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
                        )
                    else:
                        ax_left.text(
                            i,
                            show + max(1.0, 0.01 * (cap or show)),
                            _fmt_ms(val),
                            ha="center",
                            va="bottom",
                            fontsize=9,
                        )
                ax_left.set_ylim(0, cap * 1.10)
                _add_break_marker(ax_left, y=0.98)
                ax_left.set_xticks(x)
                ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
            else:
                ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
                for i, v in enumerate(values):
                    ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
                ax_left.set_xticks(x)
                ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
            ax_left.set_ylabel("Latency (ms per passage)")
            max_initial = latest_rows[0].get("max_initial", "?")
            max_updates = latest_rows[0].get("max_updates", "?")
            ax_left.set_title(
                f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
            )

        # Right subplot (A vs B, seconds) - paper style
        r_labels = ["Sequential", "Delayed \n Add+Search"]
        r_values = [a_total or 0.0, b_makespan or 0.0]
        r_styles = [
            {"edgecolor": "#59a14f", "hatch": "xxxxx"},
            {"edgecolor": "#edc948", "hatch": "/////"},
        ]
        # 2 bars, centered with proper spacing
        xr = [0, 1]
        bar_width = 0.50  # Reduced for wider spacing between bars
        for i, (v, style) in enumerate(zip(r_values, r_styles)):
            ax_right.bar(
                xr[i],
                v,
                width=bar_width,
                color="white",
                edgecolor=style["edgecolor"],
                hatch=style["hatch"],
                linewidth=1.2,
            )
        for i, v in enumerate(r_values):
            max_v = max(r_values) if r_values else 1.0
            offset = max(0.0002, 0.02 * max_v)
            ax_right.text(
                xr[i],
                v + offset,
                f"{v:.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
            )
        ax_right.set_xticks(xr)
        ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
        # Don't set ylabel here - will use fig.text for alignment
        ax_right.tick_params(axis="y", labelsize=10)
        # Add subtle grid for better readability
        ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
        ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")

        # Set x-axis limits to match left subplot's bar width visually
        # Accounting for width_ratios=[1.5, 1]:
        # Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
        # bar_width_visual = 0.72 * (1.5*unit / 4.2)
        # Right: 2 bars, need same visual width
        # 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
        # range_right = 4.2 / 1.5 = 2.8
        # For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
        ax_right.set_xlim(-0.9, 1.9)

        # Set y-axis limit with headroom for text labels
        if r_values:
            max_v = max(r_values)
            ax_right.set_ylim(0, max_v * 1.15)

        # Format y-axis to avoid scientific notation
        ax_right.ticklabel_format(style="plain", axis="y")

        plt.tight_layout()

        # Add aligned ylabels using fig.text (after tight_layout)
        # Get the vertical center of the entire figure
        fig_center_y = 0.5
        # Left ylabel - closer to left plot
        left_x = 0.05
        fig.text(
            left_x,
            fig_center_y,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )
        # Right ylabel - closer to right plot
        right_bbox = ax_right.get_position()
        right_x = right_bbox.x0 - 0.07
        fig.text(
            right_x,
            fig_center_y,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )

        plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
        # Also save PDF for paper
        pdf_out = args.out.with_suffix(".pdf")
        plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
        print(f"Saved: {args.out}")
        print(f"Saved: {pdf_out}")
        return

    # Broken-Y mode
    if args.broken_y:
        import matplotlib.pyplot as plt

        fig, (ax_top, ax_bottom) = plt.subplots(
            2,
            1,
            sharex=True,
            figsize=(7.5, 6.75),
            gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
        )

        # Determine default breaks from second-highest
        s = sorted(values, reverse=True)
        second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
        lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
        upper_start = (
            args.upper_start_y
            if args.upper_start_y is not None
            else max(second * 1.2, lower_cap * 1.02)
        )
        ymax = max(values) * 1.10 if values else 1.0

        x = list(range(len(labels)))
        ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
        ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)

        # Limits
        ax_bottom.set_ylim(0, lower_cap)
        ax_top.set_ylim(upper_start, ymax)

        # Annotate values
        for i, v in enumerate(values):
            if v <= lower_cap:
                ax_bottom.text(
                    i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
                )
            else:
                ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)

        # Hide spines between axes and draw diagonal break marks
        ax_top.spines["bottom"].set_visible(False)
        ax_bottom.spines["top"].set_visible(False)
        ax_top.tick_params(labeltop=False)  # don't put tick labels at the top
        ax_bottom.xaxis.tick_bottom()

        # Diagonal lines at the break (matching paper_fig.py style)
        d = 0.015
        kwargs = {
            "transform": ax_top.transAxes,
            "color": "k",
            "clip_on": False,
            "linewidth": 0.8,
            "zorder": 10,
        }
        ax_top.plot((-d, +d), (-d, +d), **kwargs)  # top-left diagonal
        ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal
        kwargs.update({"transform": ax_bottom.transAxes})
        ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
        ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

        ax_bottom.set_xticks(x)
        ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
        ax = ax_bottom  # for labeling below
    else:
        cap = args.cap_y
        if cap is None and not args.no_auto_cap:
            cap = _auto_cap(values)

        plt.figure(figsize=(5.4, 3.15))
        ax = plt.gca()

        if cap is not None:
            show_vals = [min(v, cap) for v in values]
            bars = []
            for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
                bar = ax.bar(i, show, color=colors[i], width=0.8)
                bars.append(bar[0])
                # Hatch and annotate when capped
                if val > cap:
                    bars[-1].set_hatch("//")
                    ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
                else:
                    ax.text(
                        i,
                        show + max(1.0, 0.01 * (cap or show)),
                        f"{_fmt_ms(val)}",
                        ha="center",
                        va="bottom",
                        fontsize=9,
                    )
            ax.set_ylim(0, cap * 1.10)
            _add_break_marker(ax, y=0.98)
            ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
                v > cap for v in values
            ) else None
            ax.set_xticks(range(len(labels)))
            ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
        else:
            ax.bar(labels, values, color=colors[: len(labels)])
            for idx, val in enumerate(values):
                ax.text(
                    idx,
                    val + 1.0,
                    f"{_fmt_ms(val)}",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )
            ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
    # Try to extract some context for title
    max_initial = latest_rows[0].get("max_initial", "?")
    max_updates = latest_rows[0].get("max_updates", "?")

    if args.broken_y:
        fig.text(
            0.02,
            0.5,
            "Latency (s)",
            va="center",
            rotation="vertical",
            fontsize=11,
            fontweight="bold",
        )
        fig.suptitle(
            "Add Operation Latency",
            fontsize=11,
            y=0.98,
            fontweight="bold",
        )
        plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
    else:
        plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
        plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
        plt.tight_layout()

    plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
    # Also save PDF for paper
    pdf_out = args.out.with_suffix(".pdf")
    plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
    print(f"Saved: {args.out}")
    print(f"Saved: {pdf_out}")


if __name__ == "__main__":
    main()
82	data/.gitattributes (vendored)
@@ -1,82 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.lz4 filter=lfs diff=lfs merge=lfs -text
-*.mds filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-# Audio files - uncompressed
-*.pcm filter=lfs diff=lfs merge=lfs -text
-*.sam filter=lfs diff=lfs merge=lfs -text
-*.raw filter=lfs diff=lfs merge=lfs -text
-# Audio files - compressed
-*.aac filter=lfs diff=lfs merge=lfs -text
-*.flac filter=lfs diff=lfs merge=lfs -text
-*.mp3 filter=lfs diff=lfs merge=lfs -text
-*.ogg filter=lfs diff=lfs merge=lfs -text
-*.wav filter=lfs diff=lfs merge=lfs -text
-# Image files - uncompressed
-*.bmp filter=lfs diff=lfs merge=lfs -text
-*.gif filter=lfs diff=lfs merge=lfs -text
-*.png filter=lfs diff=lfs merge=lfs -text
-*.tiff filter=lfs diff=lfs merge=lfs -text
-# Image files - compressed
-*.jpg filter=lfs diff=lfs merge=lfs -text
-*.jpeg filter=lfs diff=lfs merge=lfs -text
-*.webp filter=lfs diff=lfs merge=lfs -text
-# Video files - compressed
-*.mp4 filter=lfs diff=lfs merge=lfs -text
-*.webm filter=lfs diff=lfs merge=lfs -text
-ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
-indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
-indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
-indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
-indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
 including how to make donations to the Project Gutenberg Literary
 Archive Foundation, how to help produce our new eBooks, and how to
 subscribe to our email newsletter to hear about new eBooks.
-
-
277	demo.ipynb
@@ -4,7 +4,11 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "# Quick Start in 30s"
+   "# Quick Start \n",
+   "\n",
+   "**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
+   "\n",
+   "**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
   ]
  },
  {
@@ -13,8 +17,25 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "# install this if you areusing colab\n",
-   "! pip install leann"
+   "# install this if you are using colab\n",
+   "! uv pip install leann-core leann-backend-hnsw --no-deps\n",
+   "! uv pip install leann --no-deps\n",
+   "# For Colab environment, we need to set some environment variables\n",
+   "import os\n",
+   "\n",
+   "os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from pathlib import Path\n",
+   "\n",
+   "INDEX_DIR = Path(\"./\").resolve()\n",
+   "INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
   ]
  },
  {
@@ -26,91 +47,21 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [ …executed cell outputs (index build and CSR-conversion logs) omitted from this view… ],
+  "outputs": [],
   "source": [
    "from leann.api import LeannBuilder\n",
    "\n",
    "builder = LeannBuilder(backend_name=\"hnsw\")\n",
    "builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
-   "builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
+   "builder.add_text(\n",
+   "    \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
+   ")\n",
    "builder.add_text(\"Machine learning transforms industries\")\n",
    "builder.add_text(\"Neural networks process complex data\")\n",
    "builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
-   "builder.build_index(\"knowledge.leann\")"
+   "builder.build_index(INDEX_PATH)"
   ]
  },
  {
@@ -122,97 +73,13 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 2,
+  "execution_count": null,
   "metadata": {},
-  "outputs": [ …executed cell outputs (embedding-server and search logs, SearchResult repr) omitted from this view… ],
+  "outputs": [],
   "source": [
    "from leann.api import LeannSearcher\n",
    "\n",
-   "searcher = LeannSearcher(\"knowledge.leann\")\n",
+   "searcher = LeannSearcher(INDEX_PATH)\n",
    "results = searcher.search(\"programming languages\", top_k=2)\n",
    "results"
   ]
@@ -228,79 +95,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
-  "outputs": [ …executed cell outputs (chat/LLM generation logs and response repr) omitted from this view… ],
+  "outputs": [],
   "source": [
    "from leann.api import LeannChat\n",
    "\n",
@@ -309,11 +104,11 @@
    " \"model\": \"Qwen/Qwen3-0.6B\",\n",
    "}\n",
    "\n",
-   "chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
+   "chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
    "response = chat.ask(\n",
    "    \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
    "    top_k=2,\n",
-   "    llm_kwargs={\"max_tokens\": 128}\n",
+   "    llm_kwargs={\"max_tokens\": 128},\n",
    ")\n",
    "response"
   ]
223	docs/CONTRIBUTING.md (normal file)
@@ -0,0 +1,223 @@
# 🤝 Contributing

We welcome contributions! Leann is built by the community, for the community.

## Ways to Contribute

- 🐛 **Bug Reports**: Found an issue? Let us know!
- 💡 **Feature Requests**: Have an idea? We'd love to hear it!
- 🔧 **Code Contributions**: PRs welcome for all skill levels
- 📖 **Documentation**: Help make Leann more accessible
- 🧪 **Benchmarks**: Share your performance results

## 🚀 Development Setup

### Prerequisites

1. **Install uv** (fast Python package installer):
   ```bash
   curl -LsSf https://astral.sh/uv/install.sh | sh
   ```

2. **Clone the repository**:
   ```bash
   git clone https://github.com/LEANN-RAG/LEANN-RAG.git
   cd LEANN-RAG
   ```

3. **Install system dependencies**:

   **macOS:**
   ```bash
   brew install llvm libomp boost protobuf zeromq pkgconf
   ```

   **Ubuntu/Debian:**
   ```bash
   sudo apt-get install libomp-dev libboost-all-dev protobuf-compiler \
       libabsl-dev libmkl-full-dev libaio-dev libzmq3-dev
   ```

4. **Build from source**:
   ```bash
   # macOS
   CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv sync

   # Ubuntu/Debian
   uv sync
   ```

## 🔨 Pre-commit Hooks

We use pre-commit hooks to ensure code quality and consistency. This runs automatically before each commit.

### Setup Pre-commit

1. **Install pre-commit tools**:
   ```bash
   uv sync lint
   ```

2. **Install the git hooks**:
   ```bash
   pre-commit install
   ```

3. **Run pre-commit manually** (optional):
   ```bash
   uv run pre-commit run --all-files
   ```

### Pre-commit Checks

Our pre-commit configuration includes:
- **Trailing whitespace removal**
- **End-of-file fixing**
- **YAML validation**
- **Large file prevention**
- **Merge conflict detection**
- **Debug statement detection**
- **Code formatting with ruff**
- **Code linting with ruff**

## 🧪 Testing

### Running Tests

```bash
# Install test tools only (no project runtime)
uv sync --group test

# Run all tests
uv run pytest

# Run specific test file
uv run pytest test/test_filename.py

# Run with coverage
uv run pytest --cov=leann
```

### Writing Tests

- Place tests in the `test/` directory
- Follow the naming convention `test_*.py`
- Use descriptive test names that explain what's being tested
- Include both positive and negative test cases
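
A minimal example of what such a test could look like (an illustrative sketch added here, not part of the repository's CONTRIBUTING.md; the file name and import path are assumptions):

```python
# test/test_plot_helpers.py (hypothetical file)
from benchmarks.update.plot_add_ablation import _fmt_ms  # hypothetical module path


def test_fmt_ms_abbreviates_values_above_one_second():
    assert _fmt_ms(1500.0) == "1.5k"  # positive case: >= 1000 ms gets a "k" suffix


def test_fmt_ms_keeps_small_values_in_plain_milliseconds():
    assert _fmt_ms(42.0) == "42.0"  # negative case: small values stay unabbreviated
```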

## 📝 Code Style

We use `ruff` for both linting and formatting to ensure consistent code style.

### Format Your Code

```bash
# Format all files
ruff format

# Check formatting without changing files
ruff format --check
```

### Lint Your Code

```bash
# Run linter with auto-fix
ruff check --fix

# Just check without fixing
ruff check
```

### Style Guidelines

- Follow PEP 8 conventions
- Use descriptive variable names
- Add type hints where appropriate
- Write docstrings for all public functions and classes
- Keep functions focused and single-purpose
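
For instance, a small function following these guidelines could look like (an illustrative sketch added here, not code from this repository):

```python
def mean_latency_ms(latencies: list[float]) -> float:
    """Return the arithmetic mean of per-passage latencies in milliseconds."""
    if not latencies:
        return 0.0
    return sum(latencies) / len(latencies)
```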

## 🚦 CI/CD

Our CI pipeline runs automatically on all pull requests. It includes:

1. **Linting and Formatting**: Ensures code follows our style guidelines
2. **Multi-platform builds**: Tests on Ubuntu and macOS
3. **Python version matrix**: Tests on Python 3.9-3.13
4. **Wheel building**: Ensures packages can be built and distributed

### CI Commands

The CI uses the same commands as pre-commit to ensure consistency:
```bash
# Linting
ruff check .

# Format checking
ruff format --check .
```

Make sure your code passes these checks locally before pushing!

## 🔄 Pull Request Process

1. **Fork the repository** and create your branch from `main`:
   ```bash
   git checkout -b feature/your-feature-name
   ```

2. **Make your changes**:
   - Write clean, documented code
   - Add tests for new functionality
   - Update documentation as needed

3. **Run pre-commit checks**:
   ```bash
   pre-commit run --all-files
   ```

4. **Test your changes**:
   ```bash
   uv run pytest
   ```

5. **Commit with descriptive messages**:
   ```bash
   git commit -m "feat: add new search algorithm"
   ```

   Follow [Conventional Commits](https://www.conventionalcommits.org/):
   - `feat:` for new features
   - `fix:` for bug fixes
   - `docs:` for documentation changes
   - `test:` for test additions/changes
   - `refactor:` for code refactoring
   - `perf:` for performance improvements

6. **Push and create a pull request**:
   - Provide a clear description of your changes
   - Reference any related issues
   - Include examples or screenshots if applicable

## 📚 Documentation

When adding new features or making significant changes:

1. Update relevant documentation in `/docs`
2. Add docstrings to new functions/classes
3. Update README.md if needed
4. Include usage examples

## 🤔 Getting Help

- **Discord**: Join our community for discussions
- **Issues**: Check existing issues or create a new one
- **Discussions**: For general questions and ideas

## 📄 License

By contributing, you agree that your contributions will be licensed under the same license as the project (MIT).

---

Thank you for contributing to LEANN! Every contribution, no matter how small, helps make the project better for everyone. 🌟
101	docs/RELEASE.md
@@ -1,93 +1,22 @@
 # Release Guide
 
-## 📋 Prerequisites
-
-Before releasing, ensure:
-1. ✅ All code changes are committed and pushed
-2. ✅ CI has passed on the latest commit (check [Actions](https://github.com/yichuan-w/LEANN/actions/workflows/ci.yml))
-3. ✅ You have determined the new version number
-
-### Optional: TestPyPI Configuration
-
-To enable TestPyPI testing (recommended but not required):
-1. Get a TestPyPI API token from https://test.pypi.org/manage/account/token/
-2. Add it to repository secrets: Settings → Secrets → Actions → New repository secret
-   - Name: `TEST_PYPI_API_TOKEN`
-   - Value: Your TestPyPI token (starts with `pypi-`)
-
-**Note**: TestPyPI testing is optional. If not configured, the release will skip TestPyPI and proceed.
-
-## 🚀 Recommended: Manual Release Workflow
-
-### Via GitHub UI (Most Reliable)
-
-1. **Verify CI Status**: Check that the latest commit has a green checkmark ✅
-2. Go to [Actions → Manual Release](https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml)
-3. Click "Run workflow"
-4. Enter version (e.g., `0.1.1`)
-5. Toggle "Test on TestPyPI first" if desired
-6. Click "Run workflow"
-
-**What happens:**
-- ✅ Validates version format
-- ✅ Downloads pre-built packages from CI (no rebuild needed!)
-- ✅ Updates all package versions
-- ✅ Optionally tests on TestPyPI
-- ✅ Creates tag and GitHub release
-- ✅ Automatically triggers PyPI publish
-
-### Via Command Line
-
-```bash
-gh workflow run release-manual.yml -f version=0.1.1 -f test_pypi=true
-```
-
-## ⚡ Quick Release (One-Line)
-
-For experienced users who want the fastest path:
-
-```bash
-./scripts/release.sh 0.1.1
-```
-
-This script will:
-1. Update all package versions
-2. Commit and push changes
-3. Create GitHub release
-4. CI automatically builds and publishes to PyPI
-
-⚠️ **Note**: If CI fails, you'll need to manually fix and re-tag
-
-## Manual Testing Before Release
-
-For testing specific packages locally (especially DiskANN on macOS):
-
-```bash
-# Build specific package locally
-./scripts/build_and_test.sh diskann # or hnsw, core, meta, all
-
-# Test installation in a clean environment
-python -m venv test_env
-source test_env/bin/activate
-pip install packages/*/dist/*.whl
-
-# Upload to Test PyPI (optional)
-./scripts/upload_to_pypi.sh test
-
-# Upload to Production PyPI (use with caution)
-./scripts/upload_to_pypi.sh prod
-```
-
-## First-time setup
-
-1. Install GitHub CLI:
-```bash
-brew install gh
-gh auth login
-```
-
-2. Set PyPI token in GitHub:
-```bash
-gh secret set PYPI_API_TOKEN
-# Paste your PyPI token when prompted
-```
+## Setup (One-time)
+
+Add `PYPI_API_TOKEN` to GitHub Secrets:
+1. Get token: https://pypi.org/manage/account/token/
+2. Add to secrets: Settings → Secrets → Actions → `PYPI_API_TOKEN`
+
+## Release (One-click)
+
+1. Go to: https://github.com/yichuan-w/LEANN/actions/workflows/release-manual.yml
+2. Click "Run workflow"
+3. Enter version: `0.1.2`
+4. Click green "Run workflow" button
+
+That's it! The workflow will automatically:
+- ✅ Update version in all packages
+- ✅ Build all packages
+- ✅ Publish to PyPI
+- ✅ Create GitHub tag and release
+
+Check progress: https://github.com/yichuan-w/LEANN/actions
Some files were not shown because too many files have changed in this diff.