Compare commits
194 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
198044d033 | ||
|
|
a2e5f5294b | ||
|
|
8a2ea37871 | ||
|
|
7ddb4772c0 | ||
|
|
a1c21adbce | ||
|
|
d1b3c93a5a | ||
|
|
a6ee95b18a | ||
|
|
17cbd07b25 | ||
|
|
3629ccf8f7 | ||
|
|
0175bc9c20 | ||
|
|
af47dfdde7 | ||
|
|
f13bd02fbd | ||
|
|
a0bbf831db | ||
|
|
86287d8832 | ||
|
|
76cc798e3e | ||
|
|
d599566fd7 | ||
|
|
00770aebbb | ||
|
|
e268392d5b | ||
|
|
eb909ccec5 | ||
|
|
13beb98164 | ||
|
|
969f514564 | ||
|
|
1ef9cba7de | ||
|
|
a63550944b | ||
|
|
97493a2896 | ||
|
|
f7d2dc6e7c | ||
|
|
ea86b283cb | ||
|
|
e7519bceaa | ||
|
|
abf0b2c676 | ||
|
|
3c4785bb63 | ||
|
|
930b79cc98 | ||
|
|
9b7353f336 | ||
|
|
9dd0e0b26f | ||
|
|
3766ad1fd2 | ||
|
|
c3aceed1e0 | ||
|
|
dc6c9f696e | ||
|
|
2406c41eef | ||
|
|
d4f5f2896f | ||
|
|
366984e92e | ||
|
|
64b92a04a7 | ||
|
|
a85d0ad4a7 | ||
|
|
dbb5f4d352 | ||
|
|
f180b83589 | ||
|
|
abf312d998 | ||
|
|
ab251ab751 | ||
|
|
28085f6f04 | ||
|
|
6495833887 | ||
|
|
5543b3c5f7 | ||
|
|
a99983b3d9 | ||
|
|
36482e016c | ||
|
|
32967daf81 | ||
|
|
b4bb8dec75 | ||
|
|
5ba9cf6442 | ||
|
|
1484406a8d | ||
|
|
761ec1f0ac | ||
|
|
4808afc686 | ||
|
|
0bba4b2157 | ||
|
|
e67b5f44fa | ||
|
|
658bce47ef | ||
|
|
6b399ad8d2 | ||
|
|
16f35aa067 | ||
|
|
ab9c6bd69e | ||
|
|
e2b37914ce | ||
|
|
e588100674 | ||
|
|
fecee94af1 | ||
|
|
01475c10a0 | ||
|
|
c8aa063f48 | ||
|
|
576beb13db | ||
|
|
63c7b0c8a3 | ||
|
|
ec889f7ef4 | ||
|
|
322e5c162d | ||
|
|
edde0cdeb2 | ||
|
|
db7ba27ff6 | ||
|
|
5f7806e16f | ||
|
|
d034e2195b | ||
|
|
43894ff605 | ||
|
|
10311cc611 | ||
|
|
ad0d2faabc | ||
|
|
e93c0dec6f | ||
|
|
c5a29f849a | ||
|
|
3b8dc6368e | ||
|
|
e309f292de | ||
|
|
0d9f92ea0f | ||
|
|
b0b353d279 | ||
|
|
4dffdfedbe | ||
|
|
d41e467df9 | ||
|
|
4ca0489cb1 | ||
|
|
e83a671918 | ||
|
|
4e5b73ce7b | ||
|
|
31b4973141 | ||
|
|
dde2221513 | ||
|
|
6d11e86e71 | ||
|
|
13bb561aad | ||
|
|
0174ba5571 | ||
|
|
03af82d695 | ||
|
|
738f1dbab8 | ||
|
|
37d990d51c | ||
|
|
a6f07a54f1 | ||
|
|
46905e0687 | ||
|
|
838ade231e | ||
|
|
da6540decd | ||
|
|
39e18a7c11 | ||
|
|
6bde28584b | ||
|
|
f62632c41f | ||
|
|
27708243ca | ||
|
|
9a1e4652ca | ||
|
|
14e84d9e2d | ||
|
|
2dcfca19ff | ||
|
|
bee2167ee3 | ||
|
|
ef980d70b3 | ||
|
|
db3c63c441 | ||
|
|
00eeadb9dd | ||
|
|
42c8370709 | ||
|
|
fafdf8fcbe | ||
|
|
21f7d8e031 | ||
|
|
46565b9249 | ||
|
|
3dad76126a | ||
|
|
18e28bda32 | ||
|
|
609fa62fd5 | ||
|
|
eab13434ef | ||
|
|
b2390ccc14 | ||
|
|
e8fca2c84a | ||
|
|
790ae14f69 | ||
|
|
ac363072e6 | ||
|
|
93465af46c | ||
|
|
792ece67dc | ||
|
|
239e35e2e6 | ||
|
|
2fac0c6fbf | ||
|
|
9801aa581b | ||
|
|
5e97916608 | ||
|
|
8b9c2be8c9 | ||
|
|
3ff5aac8e0 | ||
|
|
67fef60466 | ||
|
|
b6ab6f1993 | ||
|
|
9f2e82a838 | ||
|
|
0b2b799d5a | ||
|
|
0f790fbbd9 | ||
|
|
387ae21eba | ||
|
|
3cc329c3e7 | ||
|
|
5567302316 | ||
|
|
075d4bd167 | ||
|
|
e4bcc76f88 | ||
|
|
710e83b1fd | ||
|
|
c96d653072 | ||
|
|
8b22d2b5d3 | ||
|
|
4cb544ee38 | ||
|
|
f94ce63d51 | ||
|
|
4271ff9d84 | ||
|
|
0d448c4a41 | ||
|
|
af5599e33c | ||
|
|
efdf6d917a | ||
|
|
dd71ac8d71 | ||
|
|
8bee1d4100 | ||
|
|
33521d6d00 | ||
|
|
8899734952 | ||
|
|
54df6310c5 | ||
|
|
19bcc07814 | ||
|
|
8356e3c668 | ||
|
|
08eac5c821 | ||
|
|
4671ed9b36 | ||
|
|
055c086398 | ||
|
|
d505dcc5e3 | ||
|
|
261006c36a | ||
|
|
b2eba23e21 | ||
|
|
e9ee687472 | ||
|
|
6f5d5e4a77 | ||
|
|
5c8921673a | ||
|
|
e9d2d420bd | ||
|
|
ebabfad066 | ||
|
|
e6f612b5e8 | ||
|
|
51c41acd82 | ||
|
|
455f93fb7c | ||
|
|
48207c3b69 | ||
|
|
4de1caa40f | ||
|
|
60eaa8165c | ||
|
|
c1a5d0c624 | ||
|
|
af1790395a | ||
|
|
383c6d8d7e | ||
|
|
bc0d839693 | ||
|
|
8596562de5 | ||
|
|
5d09586853 | ||
|
|
a7cba078dd | ||
|
|
b3e9ee96fa | ||
|
|
8537a6b17e | ||
|
|
7c8d7dc5c2 | ||
|
|
8e23d663e6 | ||
|
|
8a3994bf80 | ||
|
|
8375f601ba | ||
|
|
c87c0fe662 | ||
|
|
73927b68ef | ||
|
|
cc1a62e5aa | ||
|
|
802020cb41 | ||
|
|
cdb92f7cf4 | ||
|
|
dc69bdec00 | ||
|
|
98073e9868 |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1 +0,0 @@
|
||||
paper_plot/data/big_graph_degree_data.npz filter=lfs diff=lfs merge=lfs -text
|
||||
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
50
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Bug Report
|
||||
description: Report a bug in LEANN
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: What happened?
|
||||
description: A clear description of the bug
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduce
|
||||
attributes:
|
||||
label: How to reproduce
|
||||
placeholder: |
|
||||
1. Install with...
|
||||
2. Run command...
|
||||
3. See error
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: error
|
||||
attributes:
|
||||
label: Error message
|
||||
description: Paste any error messages
|
||||
render: shell
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: LEANN Version
|
||||
placeholder: "0.1.0"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
attributes:
|
||||
label: Operating System
|
||||
options:
|
||||
- macOS
|
||||
- Linux
|
||||
- Windows
|
||||
- Docker
|
||||
validations:
|
||||
required: true
|
||||
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
8
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Documentation
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/tree/main/docs
|
||||
about: Read the docs first
|
||||
- name: Discussions
|
||||
url: https://github.com/LEANN-RAG/LEANN-RAG/discussions
|
||||
about: Ask questions and share ideas
|
||||
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
27
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature for LEANN
|
||||
labels: ["enhancement"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: problem
|
||||
attributes:
|
||||
label: What problem does this solve?
|
||||
description: Describe the problem or need
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: solution
|
||||
attributes:
|
||||
label: Proposed solution
|
||||
description: How would you like this to work?
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: example
|
||||
attributes:
|
||||
label: Example usage
|
||||
description: Show how the API might look
|
||||
render: python
|
||||
13
.github/pull_request_template.md
vendored
Normal file
13
.github/pull_request_template.md
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
## What does this PR do?
|
||||
|
||||
<!-- Brief description of your changes -->
|
||||
|
||||
## Related Issues
|
||||
|
||||
Fixes #
|
||||
|
||||
## Checklist
|
||||
|
||||
- [ ] Tests pass (`uv run pytest`)
|
||||
- [ ] Code formatted (`ruff format` and `ruff check`)
|
||||
- [ ] Pre-commit hooks pass (`pre-commit run --all-files`)
|
||||
3
.github/workflows/build-and-publish.yml
vendored
3
.github/workflows/build-and-publish.yml
vendored
@@ -5,7 +5,8 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
|
||||
474
.github/workflows/build-reusable.yml
vendored
474
.github/workflows/build-reusable.yml
vendored
@@ -10,109 +10,268 @@ on:
|
||||
default: ''
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04
|
||||
python: '3.9'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
- os: macos-latest
|
||||
python: '3.9'
|
||||
- os: macos-latest
|
||||
python: '3.10'
|
||||
- os: macos-latest
|
||||
python: '3.11'
|
||||
- os: macos-latest
|
||||
python: '3.12'
|
||||
- os: macos-latest
|
||||
python: '3.13'
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
lint:
|
||||
name: Lint and Format Check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Run pre-commit with only lint group (no project deps)
|
||||
run: |
|
||||
uv run --only-group lint pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
type-check:
|
||||
name: Type Check with ty
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install ty
|
||||
run: uv tool install ty
|
||||
|
||||
- name: Run ty type checker
|
||||
run: |
|
||||
# Run ty on core packages, apps, and tests
|
||||
ty check packages/leann-core/src apps tests
|
||||
|
||||
build:
|
||||
needs: [lint, type-check]
|
||||
name: Build ${{ matrix.os }} Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
# Note: Python 3.9 dropped - uses PEP 604 union syntax (str | None)
|
||||
# which requires Python 3.10+
|
||||
- os: ubuntu-22.04
|
||||
python: '3.10'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.11'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.12'
|
||||
- os: ubuntu-22.04
|
||||
python: '3.13'
|
||||
# ARM64 Linux builds
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.10'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.11'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.12'
|
||||
- os: ubuntu-24.04-arm
|
||||
python: '3.13'
|
||||
- os: macos-14
|
||||
python: '3.10'
|
||||
- os: macos-14
|
||||
python: '3.11'
|
||||
- os: macos-14
|
||||
python: '3.12'
|
||||
- os: macos-14
|
||||
python: '3.13'
|
||||
- os: macos-15
|
||||
python: '3.10'
|
||||
- os: macos-15
|
||||
python: '3.11'
|
||||
- os: macos-15
|
||||
python: '3.12'
|
||||
- os: macos-15
|
||||
python: '3.13'
|
||||
# Intel Mac builds (x86_64) - replaces deprecated macos-13
|
||||
# Note: Python 3.13 excluded - PyTorch has no wheels for macOS x86_64 + Python 3.13
|
||||
# (PyTorch <=2.4.1 lacks cp313, PyTorch >=2.5.0 dropped Intel Mac support)
|
||||
- os: macos-15-intel
|
||||
python: '3.10'
|
||||
- os: macos-15-intel
|
||||
python: '3.11'
|
||||
- os: macos-15-intel
|
||||
python: '3.12'
|
||||
# macOS 26 (beta) - arm64
|
||||
- os: macos-26
|
||||
python: '3.10'
|
||||
- os: macos-26
|
||||
python: '3.11'
|
||||
- os: macos-26
|
||||
python: '3.12'
|
||||
- os: macos-26
|
||||
python: '3.13'
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
with:
|
||||
ref: ${{ inputs.ref }}
|
||||
submodules: recursive
|
||||
|
||||
- name: Install uv and Python
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
|
||||
|
||||
- name: Install system dependencies (Ubuntu)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libomp-dev libboost-all-dev protobuf-compiler libzmq3-dev \
|
||||
pkg-config libopenblas-dev patchelf libabsl-dev libaio-dev libprotobuf-dev
|
||||
|
||||
# Install Intel MKL for DiskANN
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/latest/lib/intel64:$LD_LIBRARY_PATH" >> $GITHUB_ENV
|
||||
|
||||
pkg-config libabsl-dev libaio-dev libprotobuf-dev \
|
||||
patchelf
|
||||
|
||||
# Debug: Show system information
|
||||
echo "🔍 System Information:"
|
||||
echo "Architecture: $(uname -m)"
|
||||
echo "OS: $(uname -a)"
|
||||
echo "CPU info: $(lscpu | head -5)"
|
||||
|
||||
# Install math library based on architecture
|
||||
ARCH=$(uname -m)
|
||||
echo "🔍 Setting up math library for architecture: $ARCH"
|
||||
|
||||
if [[ "$ARCH" == "x86_64" ]]; then
|
||||
# Install Intel MKL for DiskANN on x86_64
|
||||
echo "📦 Installing Intel MKL for x86_64..."
|
||||
wget -q https://registrationcenter-download.intel.com/akdlm/IRC_NAS/79153e0f-74d7-45af-b8c2-258941adf58a/intel-onemkl-2025.0.0.940.sh
|
||||
sudo sh intel-onemkl-2025.0.0.940.sh -a --components intel.oneapi.lin.mkl.devel --action install --eula accept -s
|
||||
source /opt/intel/oneapi/setvars.sh
|
||||
echo "MKLROOT=/opt/intel/oneapi/mkl/latest" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=/opt/intel/oneapi/compiler/latest/linux/compiler/lib/intel64_lin" >> $GITHUB_ENV
|
||||
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/mkl/latest/lib/intel64" >> $GITHUB_ENV
|
||||
echo "✅ Intel MKL installed for x86_64"
|
||||
|
||||
# Debug: Check MKL installation
|
||||
echo "🔍 MKL Installation Check:"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/ || echo "MKL directory not found"
|
||||
ls -la /opt/intel/oneapi/mkl/latest/lib/ || echo "MKL lib directory not found"
|
||||
|
||||
elif [[ "$ARCH" == "aarch64" ]]; then
|
||||
# Use OpenBLAS for ARM64 (MKL installer not compatible with ARM64)
|
||||
echo "📦 Installing OpenBLAS for ARM64..."
|
||||
sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev
|
||||
echo "✅ OpenBLAS installed for ARM64"
|
||||
|
||||
# Debug: Check OpenBLAS installation
|
||||
echo "🔍 OpenBLAS Installation Check:"
|
||||
dpkg -l | grep openblas || echo "OpenBLAS package not found"
|
||||
ls -la /usr/lib/aarch64-linux-gnu/openblas/ || echo "OpenBLAS directory not found"
|
||||
fi
|
||||
|
||||
# Debug: Show final library paths
|
||||
echo "🔍 Final LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
||||
|
||||
- name: Install system dependencies (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
brew install llvm libomp boost protobuf zeromq
|
||||
|
||||
# Don't install LLVM, use system clang for better compatibility
|
||||
brew install libomp boost protobuf zeromq
|
||||
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
uv pip install --system scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --system auditwheel
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }} .uv-build
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
BUILD_PY=".uv-build\\Scripts\\python.exe"
|
||||
else
|
||||
uv pip install --system delocate
|
||||
BUILD_PY=".uv-build/bin/python"
|
||||
fi
|
||||
|
||||
uv pip install --python "$BUILD_PY" scikit-build-core numpy swig Cython pybind11
|
||||
if [[ "$RUNNER_OS" == "Linux" ]]; then
|
||||
uv pip install --python "$BUILD_PY" auditwheel
|
||||
else
|
||||
uv pip install --python "$BUILD_PY" delocate
|
||||
fi
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
echo "$(pwd)\\.uv-build\\Scripts" >> $GITHUB_PATH
|
||||
else
|
||||
echo "$(pwd)/.uv-build/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Set macOS environment variables
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Use brew --prefix to automatically detect Homebrew installation path
|
||||
HOMEBREW_PREFIX=$(brew --prefix)
|
||||
echo "HOMEBREW_PREFIX=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
echo "OpenMP_ROOT=${HOMEBREW_PREFIX}/opt/libomp" >> $GITHUB_ENV
|
||||
|
||||
# Set CMAKE_PREFIX_PATH to let CMake find all packages automatically
|
||||
echo "CMAKE_PREFIX_PATH=${HOMEBREW_PREFIX}" >> $GITHUB_ENV
|
||||
|
||||
# Set compiler flags for OpenMP (required for both backends)
|
||||
echo "LDFLAGS=-L${HOMEBREW_PREFIX}/opt/libomp/lib" >> $GITHUB_ENV
|
||||
echo "CPPFLAGS=-I${HOMEBREW_PREFIX}/opt/libomp/include" >> $GITHUB_ENV
|
||||
|
||||
- name: Build packages
|
||||
run: |
|
||||
# Build core (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
|
||||
cd packages/leann-core
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
# Build HNSW backend
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python python
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Build DiskANN backend
|
||||
cd packages/leann-backend-diskann
|
||||
if [ "${{ matrix.os }}" == "macos-latest" ]; then
|
||||
CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ uv build --wheel --python python
|
||||
if [[ "${{ matrix.os }}" == macos-* ]]; then
|
||||
# Use system clang for better compatibility
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
# Set deployment target based on runner
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
else
|
||||
uv build --wheel --python python
|
||||
uv build --wheel --python ${{ matrix.python }} --find-links ${GITHUB_WORKSPACE}/packages/leann-core/dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Build meta package (platform independent)
|
||||
if [[ "${{ matrix.os }}" == ubuntu-* ]]; then
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
fi
|
||||
|
||||
cd packages/leann
|
||||
uv build
|
||||
cd ../..
|
||||
|
||||
- name: Repair wheels (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
@@ -124,7 +283,7 @@ jobs:
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
@@ -133,35 +292,194 @@ jobs:
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
- name: Repair wheels (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Determine deployment target based on runner OS
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
HNSW_TARGET="14.0"
|
||||
DISKANN_TARGET="14.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
HNSW_TARGET="15.0"
|
||||
DISKANN_TARGET="15.0"
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
HNSW_TARGET="26.0"
|
||||
DISKANN_TARGET="26.0"
|
||||
fi
|
||||
|
||||
# Repair HNSW wheel
|
||||
cd packages/leann-backend-hnsw
|
||||
if [ -d dist ]; then
|
||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||
export MACOSX_DEPLOYMENT_TARGET=$HNSW_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $HNSW_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
# Repair DiskANN wheel
|
||||
cd packages/leann-backend-diskann
|
||||
if [ -d dist ]; then
|
||||
delocate-wheel -w dist_repaired -v dist/*.whl
|
||||
export MACOSX_DEPLOYMENT_TARGET=$DISKANN_TARGET
|
||||
delocate-wheel -w dist_repaired -v --require-target-macos-version $DISKANN_TARGET dist/*.whl
|
||||
rm -rf dist
|
||||
mv dist_repaired dist
|
||||
fi
|
||||
cd ../..
|
||||
|
||||
|
||||
- name: List built packages
|
||||
run: |
|
||||
echo "📦 Built packages:"
|
||||
find packages/*/dist -name "*.whl" -o -name "*.tar.gz" | sort
|
||||
|
||||
|
||||
|
||||
- name: Install built packages for testing
|
||||
run: |
|
||||
# Create uv-managed virtual environment with the requested interpreter
|
||||
uv python install ${{ matrix.python }}
|
||||
uv venv --python ${{ matrix.python }}
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
if [[ "$RUNNER_OS" == "Windows" ]]; then
|
||||
UV_PY=".venv\\Scripts\\python.exe"
|
||||
else
|
||||
UV_PY=".venv/bin/python"
|
||||
fi
|
||||
|
||||
# Install test dependency group only (avoids reinstalling project package)
|
||||
uv pip install --python "$UV_PY" --group test
|
||||
|
||||
# Install core wheel built in this job
|
||||
CORE_WHL=$(find packages/leann-core/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$CORE_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$CORE_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann-core/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
PY_TAG=$($UV_PY -c "import sys; print(f'cp{sys.version_info[0]}{sys.version_info[1]}')")
|
||||
|
||||
if [[ "$RUNNER_OS" == "macOS" ]]; then
|
||||
# macos-15-intel runs macOS 15, so target 15.0 (system libraries require it)
|
||||
if [[ "${{ matrix.os }}" == "macos-15-intel" ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-14* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=14.0
|
||||
elif [[ "${{ matrix.os }}" == macos-15* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=15.0
|
||||
elif [[ "${{ matrix.os }}" == macos-26* ]]; then
|
||||
export MACOSX_DEPLOYMENT_TARGET=26.0
|
||||
fi
|
||||
fi
|
||||
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$HNSW_WHL" ]]; then
|
||||
HNSW_WHL=$(find packages/leann-backend-hnsw/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$HNSW_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$HNSW_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-hnsw
|
||||
fi
|
||||
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-${PY_TAG}-*.whl" -print -quit)
|
||||
if [[ -z "$DISKANN_WHL" ]]; then
|
||||
DISKANN_WHL=$(find packages/leann-backend-diskann/dist -maxdepth 1 -name "*-py3-*.whl" -print -quit)
|
||||
fi
|
||||
if [[ -n "$DISKANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$DISKANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" ./packages/leann-backend-diskann
|
||||
fi
|
||||
|
||||
LEANN_WHL=$(find packages/leann/dist -maxdepth 1 -name "*.whl" -print -quit)
|
||||
if [[ -n "$LEANN_WHL" ]]; then
|
||||
uv pip install --python "$UV_PY" "$LEANN_WHL"
|
||||
else
|
||||
uv pip install --python "$UV_PY" packages/leann/dist/*.tar.gz
|
||||
fi
|
||||
|
||||
- name: Run tests with pytest
|
||||
env:
|
||||
CI: true
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
HF_HUB_DISABLE_SYMLINKS: 1
|
||||
TOKENIZERS_PARALLELISM: false
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 0
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
pytest tests/ -v --tb=short
|
||||
|
||||
- name: Run sanity checks (optional)
|
||||
run: |
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
|
||||
# Run distance function tests if available
|
||||
if [ -f test/sanity_checks/test_distance_functions.py ]; then
|
||||
echo "Running distance function sanity checks..."
|
||||
python test/sanity_checks/test_distance_functions.py || echo "⚠️ Distance function test failed, continuing..."
|
||||
fi
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: packages-${{ matrix.os }}-py${{ matrix.python }}
|
||||
path: packages/*/dist/
|
||||
path: packages/*/dist/
|
||||
|
||||
|
||||
arch-smoke:
|
||||
name: Arch Linux smoke test (install & import)
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: archlinux:latest
|
||||
|
||||
steps:
|
||||
- name: Prepare system
|
||||
run: |
|
||||
pacman -Syu --noconfirm
|
||||
pacman -S --noconfirm python python-pip gcc git zlib openssl
|
||||
|
||||
- name: Download ALL wheel artifacts from this run
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
# Don't specify name, download all artifacts
|
||||
path: ./wheels
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Create virtual environment and install wheels
|
||||
run: |
|
||||
uv venv
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
uv pip install --find-links wheels leann-core
|
||||
uv pip install --find-links wheels leann-backend-hnsw
|
||||
uv pip install --find-links wheels leann-backend-diskann
|
||||
uv pip install --find-links wheels leann
|
||||
|
||||
- name: Import & tiny runtime check
|
||||
env:
|
||||
OMP_NUM_THREADS: 1
|
||||
MKL_NUM_THREADS: 1
|
||||
run: |
|
||||
source .venv/bin/activate || source .venv/Scripts/activate
|
||||
python - <<'PY'
|
||||
import leann
|
||||
import leann_backend_hnsw as h
|
||||
import leann_backend_diskann as d
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
b = LeannBuilder(backend_name="hnsw")
|
||||
b.add_text("hello arch")
|
||||
b.build_index("arch_demo.leann")
|
||||
s = LeannSearcher("arch_demo.leann")
|
||||
print("search:", s.search("hello", top_k=1))
|
||||
PY
|
||||
|
||||
19
.github/workflows/link-check.yml
vendored
Normal file
19
.github/workflows/link-check.yml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
name: Link Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main, master ]
|
||||
pull_request:
|
||||
schedule:
|
||||
- cron: "0 3 * * 1"
|
||||
|
||||
jobs:
|
||||
link-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: lycheeverse/lychee-action@v2
|
||||
with:
|
||||
args: --no-progress --insecure --user-agent 'curl/7.68.0' --exclude '.*api\.star-history\.com.*' --accept 200,201,202,203,204,205,206,207,208,226,300,301,302,303,304,305,306,307,308,503 README.md docs/ apps/ examples/ benchmarks/
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
43
.github/workflows/release-manual.yml
vendored
43
.github/workflows/release-manual.yml
vendored
@@ -16,18 +16,21 @@ jobs:
|
||||
contents: write
|
||||
outputs:
|
||||
commit-sha: ${{ steps.push.outputs.commit-sha }}
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Validate version
|
||||
run: |
|
||||
if ! [[ "${{ inputs.version }}" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "❌ Invalid version format"
|
||||
# Remove 'v' prefix if present for validation
|
||||
VERSION_CLEAN="${{ inputs.version }}"
|
||||
VERSION_CLEAN="${VERSION_CLEAN#v}"
|
||||
if ! [[ "$VERSION_CLEAN" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "❌ Invalid version format. Expected format: X.Y.Z or vX.Y.Z"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Version format valid"
|
||||
|
||||
echo "✅ Version format valid: ${{ inputs.version }}"
|
||||
|
||||
- name: Update versions and push
|
||||
id: push
|
||||
run: |
|
||||
@@ -35,7 +38,7 @@ jobs:
|
||||
CURRENT_VERSION=$(grep "^version" packages/leann-core/pyproject.toml | cut -d'"' -f2)
|
||||
echo "Current version: $CURRENT_VERSION"
|
||||
echo "Target version: ${{ inputs.version }}"
|
||||
|
||||
|
||||
if [ "$CURRENT_VERSION" = "${{ inputs.version }}" ]; then
|
||||
echo "⚠️ Version is already ${{ inputs.version }}, skipping update"
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
@@ -49,7 +52,7 @@ jobs:
|
||||
COMMIT_SHA=$(git rev-parse HEAD)
|
||||
echo "✅ Pushed version update: $COMMIT_SHA"
|
||||
fi
|
||||
|
||||
|
||||
echo "commit-sha=$COMMIT_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
build-packages:
|
||||
@@ -57,7 +60,7 @@ jobs:
|
||||
needs: update-version
|
||||
uses: ./.github/workflows/build-reusable.yml
|
||||
with:
|
||||
ref: 'main'
|
||||
ref: 'main'
|
||||
|
||||
publish:
|
||||
name: Publish and Release
|
||||
@@ -66,26 +69,26 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'main'
|
||||
|
||||
ref: 'main'
|
||||
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: dist-artifacts
|
||||
|
||||
|
||||
- name: Collect packages
|
||||
run: |
|
||||
mkdir -p dist
|
||||
find dist-artifacts -name "*.whl" -exec cp {} dist/ \;
|
||||
find dist-artifacts -name "*.tar.gz" -exec cp {} dist/ \;
|
||||
|
||||
|
||||
echo "📦 Packages to publish:"
|
||||
ls -la dist/
|
||||
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
TWINE_USERNAME: __token__
|
||||
@@ -95,12 +98,12 @@ jobs:
|
||||
echo "❌ PYPI_API_TOKEN not configured!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
pip install twine
|
||||
twine upload dist/* --skip-existing --verbose
|
||||
|
||||
|
||||
echo "✅ Published to PyPI!"
|
||||
|
||||
|
||||
- name: Create release
|
||||
run: |
|
||||
# Check if tag already exists
|
||||
@@ -111,7 +114,7 @@ jobs:
|
||||
git push origin "v${{ inputs.version }}"
|
||||
echo "✅ Created and pushed tag v${{ inputs.version }}"
|
||||
fi
|
||||
|
||||
|
||||
# Check if release already exists
|
||||
if gh release view "v${{ inputs.version }}" >/dev/null 2>&1; then
|
||||
echo "⚠️ Release v${{ inputs.version }} already exists, skipping release creation"
|
||||
@@ -123,4 +126,4 @@ jobs:
|
||||
echo "✅ Created GitHub release v${{ inputs.version }}"
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
37
.gitignore
vendored
37
.gitignore
vendored
@@ -9,7 +9,7 @@ demo/indices/
|
||||
outputs/
|
||||
*.pkl
|
||||
*.pdf
|
||||
*.idx
|
||||
*.idx
|
||||
*.map
|
||||
.history/
|
||||
lm_eval.egg-info/
|
||||
@@ -18,9 +18,12 @@ demo/experiment_results/**/*.json
|
||||
*.eml
|
||||
*.emlx
|
||||
*.json
|
||||
*.png
|
||||
!.vscode/*.json
|
||||
*.sh
|
||||
*.txt
|
||||
!CMakeLists.txt
|
||||
!llms.txt
|
||||
latency_breakdown*.json
|
||||
experiment_results/eval_results/diskann/*.json
|
||||
aws/
|
||||
@@ -34,11 +37,15 @@ build/
|
||||
nprobe_logs/
|
||||
micro/results
|
||||
micro/contriever-INT8
|
||||
examples/data/*
|
||||
!examples/data/2501.14312v1 (1).pdf
|
||||
!examples/data/2506.08276v1.pdf
|
||||
!examples/data/PrideandPrejudice.txt
|
||||
!examples/data/README.md
|
||||
data/*
|
||||
!data/2501.14312v1 (1).pdf
|
||||
!data/2506.08276v1.pdf
|
||||
!data/PrideandPrejudice.txt
|
||||
!data/huawei_pangu.md
|
||||
!data/ground_truth/
|
||||
!data/indices/
|
||||
!data/queries/
|
||||
!data/.gitattributes
|
||||
*.qdstrm
|
||||
benchmark_results/
|
||||
results/
|
||||
@@ -84,5 +91,21 @@ packages/leann-backend-diskann/third_party/DiskANN/_deps/
|
||||
|
||||
*.meta.json
|
||||
*.passages.json
|
||||
*.npy
|
||||
*.db
|
||||
batchtest.py
|
||||
tests/__pytest_cache__/
|
||||
tests/__pycache__/
|
||||
benchmarks/data/
|
||||
|
||||
batchtest.py
|
||||
## multi vector
|
||||
apps/multimodal/vision-based-pdf-multi-vector/multi-vector-colpali-native-weaviate.py
|
||||
|
||||
# Ignore all PDFs (keep data exceptions above) and do not track demo PDFs
|
||||
# If you need to commit a specific demo PDF, remove this negation locally.
|
||||
# The following line used to force-add a large demo PDF; remove it to satisfy pre-commit:
|
||||
# !apps/multimodal/vision-based-pdf-multi-vector/pdfs/2004.12832v2.pdf
|
||||
!apps/multimodal/vision-based-pdf-multi-vector/fig/*
|
||||
|
||||
# AUR build directory (Arch Linux)
|
||||
paru-bin/
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -14,3 +14,6 @@
|
||||
[submodule "packages/leann-backend-hnsw/third_party/libzmq"]
|
||||
path = packages/leann-backend-hnsw/third_party/libzmq
|
||||
url = https://github.com/zeromq/libzmq.git
|
||||
[submodule "packages/astchunk-leann"]
|
||||
path = packages/astchunk-leann
|
||||
url = https://github.com/yichuan-w/astchunk-leann.git
|
||||
|
||||
17
.pre-commit-config.yaml
Normal file
17
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,17 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: debug-statements
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.7 # Fixed version to match pyproject.toml
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- id: ruff-format
|
||||
5
.vscode/extensions.json
vendored
Normal file
5
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"recommendations": [
|
||||
"charliermarsh.ruff",
|
||||
]
|
||||
}
|
||||
22
.vscode/settings.json
vendored
Normal file
22
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"python.defaultInterpreterPath": ".venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"[python]": {
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit",
|
||||
"source.fixAll": "explicit"
|
||||
},
|
||||
"editor.insertSpaces": true,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"ruff.enable": true,
|
||||
"files.watcherExclude": {
|
||||
"**/.venv/**": true,
|
||||
"**/__pycache__/**": true,
|
||||
"**/*.egg-info/**": true,
|
||||
"**/build/**": true,
|
||||
"**/dist/**": true
|
||||
}
|
||||
}
|
||||
0
apps/__init__.py
Normal file
0
apps/__init__.py
Normal file
413
apps/base_rag_example.py
Normal file
413
apps/base_rag_example.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Base class for unified RAG examples interface.
|
||||
Provides common parameters and functionality for all RAG examples.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dotenv
|
||||
from leann.api import LeannBuilder, LeannChat
|
||||
|
||||
# Optional import: older PyPI builds may not include interactive_utils
|
||||
try:
|
||||
from leann.interactive_utils import create_rag_session
|
||||
except ImportError:
|
||||
|
||||
def create_rag_session(app_name: str, data_description: str):
|
||||
class _SimpleSession:
|
||||
def run_interactive_loop(self, handler):
|
||||
print(f"Interactive session for {app_name}: {data_description}")
|
||||
print("Interactive mode not available in this build")
|
||||
|
||||
return _SimpleSession()
|
||||
|
||||
|
||||
from leann.registry import register_project_directory
|
||||
|
||||
# Optional import: older PyPI builds may not include settings
|
||||
try:
|
||||
from leann.settings import resolve_ollama_host, resolve_openai_api_key, resolve_openai_base_url
|
||||
except ImportError:
|
||||
# Minimal fallbacks if settings helpers are unavailable
|
||||
import os
|
||||
|
||||
def resolve_ollama_host(value: str | None) -> str | None:
|
||||
return value or os.getenv("LEANN_OLLAMA_HOST") or os.getenv("OLLAMA_HOST")
|
||||
|
||||
def resolve_openai_api_key(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_API_KEY")
|
||||
|
||||
def resolve_openai_base_url(value: str | None) -> str | None:
|
||||
return value or os.getenv("OPENAI_BASE_URL")
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
|
||||
|
||||
class BaseRAGExample(ABC):
|
||||
"""Base class for all RAG examples with unified interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
default_index_name: str,
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.default_index_name = default_index_name
|
||||
self.parser = self._create_parser()
|
||||
|
||||
def _create_parser(self) -> argparse.ArgumentParser:
|
||||
"""Create argument parser with common parameters."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=self.description, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
|
||||
# Core parameters (all examples share these)
|
||||
core_group = parser.add_argument_group("Core Parameters")
|
||||
core_group.add_argument(
|
||||
"--index-dir",
|
||||
type=str,
|
||||
default=f"./{self.default_index_name}",
|
||||
help=f"Directory to store the index (default: ./{self.default_index_name})",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Query to run (if not provided, will run in interactive mode)",
|
||||
)
|
||||
# Allow subclasses to override default max_items
|
||||
max_items_default = getattr(self, "max_items_default", -1)
|
||||
core_group.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=max_items_default,
|
||||
help="Maximum number of items to process -1 for all, means index all documents, and you should set it to a reasonable number if you have a large dataset and try at the first time)",
|
||||
)
|
||||
core_group.add_argument(
|
||||
"--force-rebuild", action="store_true", help="Force rebuild index even if it exists"
|
||||
)
|
||||
|
||||
# Embedding parameters
|
||||
embedding_group = parser.add_argument_group("Embedding Parameters")
|
||||
# Allow subclasses to override default embedding_model
|
||||
embedding_model_default = getattr(self, "embedding_model_default", "facebook/contriever")
|
||||
embedding_group.add_argument(
|
||||
"--embedding-model",
|
||||
type=str,
|
||||
default=embedding_model_default,
|
||||
help=f"Embedding model to use (default: {embedding_model_default}), we provide facebook/contriever, text-embedding-3-small,mlx-community/Qwen3-Embedding-0.6B-8bit or nomic-embed-text",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-mode",
|
||||
type=str,
|
||||
default="sentence-transformers",
|
||||
choices=["sentence-transformers", "openai", "mlx", "ollama"],
|
||||
help="Embedding backend mode (default: sentence-transformers), we provide sentence-transformers, openai, mlx, or ollama",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Override Ollama-compatible embedding host",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible embedding services",
|
||||
)
|
||||
embedding_group.add_argument(
|
||||
"--embedding-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for embedding service (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# LLM parameters
|
||||
llm_group = parser.add_argument_group("LLM Parameters")
|
||||
llm_group.add_argument(
|
||||
"--llm",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=["openai", "ollama", "hf", "simulated"],
|
||||
help="LLM backend: openai, ollama, or hf (default: openai)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name (default: gpt-4o) e.g., gpt-4o-mini, llama3.2:1b, Qwen/Qwen2.5-1.5B-Instruct",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-host",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Host for Ollama-compatible APIs (defaults to LEANN_OLLAMA_HOST/OLLAMA_HOST)",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--thinking-budget",
|
||||
type=str,
|
||||
choices=["low", "medium", "high"],
|
||||
default=None,
|
||||
help="Thinking budget for reasoning models (low/medium/high). Supported by GPT-Oss:20b and other reasoning models.",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-base",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Base URL for OpenAI-compatible APIs",
|
||||
)
|
||||
llm_group.add_argument(
|
||||
"--llm-api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="API key for OpenAI-compatible APIs (defaults to OPENAI_API_KEY)",
|
||||
)
|
||||
|
||||
# AST Chunking parameters
|
||||
ast_group = parser.add_argument_group("AST Chunking Parameters")
|
||||
ast_group.add_argument(
|
||||
"--use-ast-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files (requires astchunk)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-size",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Maximum CHARACTERS per AST chunk (default: 300). Final chunks may be larger due to overlap. For 512 token models: recommended 300 chars",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-chunk-overlap",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Overlap between AST chunks in CHARACTERS (default: 64). Added to chunk size, not included in it",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--code-file-extensions",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Additional code file extensions to process with AST chunking (e.g., .py .java .cs .ts)",
|
||||
)
|
||||
ast_group.add_argument(
|
||||
"--ast-fallback-traditional",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Fall back to traditional chunking if AST chunking fails (default: True)",
|
||||
)
|
||||
|
||||
# Search parameters
|
||||
search_group = parser.add_argument_group("Search Parameters")
|
||||
search_group.add_argument(
|
||||
"--top-k", type=int, default=20, help="Number of results to retrieve (default: 20)"
|
||||
)
|
||||
search_group.add_argument(
|
||||
"--search-complexity",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Search complexity for graph traversal (default: 64)",
|
||||
)
|
||||
|
||||
# Index building parameters
|
||||
index_group = parser.add_argument_group("Index Building Parameters")
|
||||
index_group.add_argument(
|
||||
"--backend-name",
|
||||
type=str,
|
||||
default="hnsw",
|
||||
choices=["hnsw", "diskann"],
|
||||
help="Backend to use for index (default: hnsw)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--graph-degree",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Graph degree for index construction (default: 32)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--build-complexity",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Build complexity for index construction (default: 64)",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-compact",
|
||||
action="store_true",
|
||||
help="Disable compact index storage",
|
||||
)
|
||||
index_group.add_argument(
|
||||
"--no-recompute",
|
||||
action="store_true",
|
||||
help="Disable embedding recomputation",
|
||||
)
|
||||
|
||||
# Add source-specific parameters
|
||||
self._add_specific_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add source-specific arguments. Override in subclasses."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load data from the source. Returns list of text chunks as dicts with 'text' and 'metadata' keys."""
|
||||
pass
|
||||
|
||||
def get_llm_config(self, args) -> dict[str, Any]:
|
||||
"""Get LLM configuration based on arguments."""
|
||||
config = {"type": args.llm}
|
||||
|
||||
if args.llm == "openai":
|
||||
config["model"] = args.llm_model or "gpt-4o"
|
||||
config["base_url"] = resolve_openai_base_url(args.llm_api_base)
|
||||
resolved_key = resolve_openai_api_key(args.llm_api_key)
|
||||
if resolved_key:
|
||||
config["api_key"] = resolved_key
|
||||
elif args.llm == "ollama":
|
||||
config["model"] = args.llm_model or "llama3.2:1b"
|
||||
config["host"] = resolve_ollama_host(args.llm_host)
|
||||
elif args.llm == "hf":
|
||||
config["model"] = args.llm_model or "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
elif args.llm == "simulated":
|
||||
# Simulated LLM doesn't need additional configuration
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build LEANN index from text chunks (dicts with 'text' and 'metadata' keys)."""
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
|
||||
print(f"\n[Building Index] Creating {self.name} index...")
|
||||
print(f"Total text chunks: {len(texts)}")
|
||||
|
||||
embedding_options: dict[str, Any] = {}
|
||||
if args.embedding_mode == "ollama":
|
||||
embedding_options["host"] = resolve_ollama_host(args.embedding_host)
|
||||
elif args.embedding_mode == "openai":
|
||||
embedding_options["base_url"] = resolve_openai_base_url(args.embedding_api_base)
|
||||
resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key)
|
||||
if resolved_embedding_key:
|
||||
embedding_options["api_key"] = resolved_embedding_key
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=args.embedding_model,
|
||||
embedding_mode=args.embedding_mode,
|
||||
embedding_options=embedding_options or None,
|
||||
graph_degree=args.graph_degree,
|
||||
complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
is_recompute=not args.no_recompute,
|
||||
num_threads=1, # Force single-threaded mode
|
||||
)
|
||||
|
||||
# Add texts in batches for better progress tracking
|
||||
batch_size = 1000
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i : i + batch_size]
|
||||
for item in batch:
|
||||
# Handle both dict format (from create_text_chunks) and plain strings
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
metadata = item.get("metadata")
|
||||
builder.add_text(text, metadata)
|
||||
else:
|
||||
builder.add_text(item)
|
||||
print(f"Added {min(i + batch_size, len(texts))}/{len(texts)} texts...")
|
||||
|
||||
print("Building index structure...")
|
||||
builder.build_index(index_path)
|
||||
print(f"Index saved to: {index_path}")
|
||||
|
||||
# Register project directory so leann list can discover this index
|
||||
# The index is saved as args.index_dir/index_name.leann
|
||||
# We want to register the current working directory where the app is run
|
||||
register_project_directory(Path.cwd())
|
||||
|
||||
return index_path
|
||||
|
||||
async def run_interactive_chat(self, args, index_path: str):
|
||||
"""Run interactive chat with the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
system_prompt=f"You are a helpful assistant that answers questions about {self.name} data.",
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
# Create interactive session
|
||||
session = create_rag_session(
|
||||
app_name=self.name.lower().replace(" ", "_"), data_description=self.name
|
||||
)
|
||||
|
||||
def handle_query(query: str):
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query,
|
||||
top_k=args.top_k,
|
||||
complexity=args.search_complexity,
|
||||
llm_kwargs=llm_kwargs,
|
||||
)
|
||||
print(f"\nAssistant: {response}\n")
|
||||
|
||||
session.run_interactive_loop(handle_query)
|
||||
|
||||
async def run_single_query(self, args, index_path: str, query: str):
|
||||
"""Run a single query against the index."""
|
||||
chat = LeannChat(
|
||||
index_path,
|
||||
llm_config=self.get_llm_config(args),
|
||||
complexity=args.search_complexity,
|
||||
)
|
||||
|
||||
print(f"\n[Query]: \033[36m{query}\033[0m")
|
||||
|
||||
# Prepare LLM kwargs with thinking budget if specified
|
||||
llm_kwargs = {}
|
||||
if hasattr(args, "thinking_budget") and args.thinking_budget:
|
||||
llm_kwargs["thinking_budget"] = args.thinking_budget
|
||||
|
||||
response = chat.ask(
|
||||
query, top_k=args.top_k, complexity=args.search_complexity, llm_kwargs=llm_kwargs
|
||||
)
|
||||
print(f"\n[Response]: \033[36m{response}\033[0m")
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point for the example."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Check if index exists
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
index_exists = Path(args.index_dir).exists()
|
||||
|
||||
if not index_exists or args.force_rebuild:
|
||||
# Load data and build index
|
||||
print(f"\n{'Rebuilding' if index_exists else 'Building'} index...")
|
||||
texts = await self.load_data(args)
|
||||
|
||||
if not texts:
|
||||
print("No data found to index!")
|
||||
return
|
||||
|
||||
index_path = await self.build_index(args, texts)
|
||||
else:
|
||||
print(f"\nUsing existing index in {args.index_dir}")
|
||||
|
||||
# Run query or interactive mode
|
||||
if args.query:
|
||||
await self.run_single_query(args, index_path, args.query)
|
||||
else:
|
||||
await self.run_interactive_chat(args, index_path)
|
||||
172
apps/browser_rag.py
Normal file
172
apps/browser_rag.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Browser History RAG example using the unified interface.
|
||||
Supports Chrome browser history.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .history_data.history import ChromeHistoryReader
|
||||
|
||||
|
||||
class BrowserRAG(BaseRAGExample):
|
||||
"""RAG example for Chrome browser history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Browser History",
|
||||
description="Process and query Chrome browser history with LEANN",
|
||||
default_index_name="google_history_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add browser-specific arguments."""
|
||||
browser_group = parser.add_argument_group("Browser Parameters")
|
||||
browser_group.add_argument(
|
||||
"--chrome-profile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Chrome profile directory (auto-detected if not specified)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--auto-find-profiles",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Automatically find all Chrome profiles (default: True)",
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
browser_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _get_chrome_base_path(self) -> Path:
|
||||
"""Get the base Chrome profile path based on OS."""
|
||||
if sys.platform == "darwin":
|
||||
return Path.home() / "Library" / "Application Support" / "Google" / "Chrome"
|
||||
elif sys.platform.startswith("linux"):
|
||||
return Path.home() / ".config" / "google-chrome"
|
||||
elif sys.platform == "win32":
|
||||
return Path(os.environ["LOCALAPPDATA"]) / "Google" / "Chrome" / "User Data"
|
||||
else:
|
||||
raise ValueError(f"Unsupported platform: {sys.platform}")
|
||||
|
||||
def _find_chrome_profiles(self) -> list[Path]:
|
||||
"""Auto-detect all Chrome profiles."""
|
||||
base_path = self._get_chrome_base_path()
|
||||
if not base_path.exists():
|
||||
return []
|
||||
|
||||
profiles = []
|
||||
|
||||
# Check Default profile
|
||||
default_profile = base_path / "Default"
|
||||
if default_profile.exists() and (default_profile / "History").exists():
|
||||
profiles.append(default_profile)
|
||||
|
||||
# Check numbered profiles
|
||||
for item in base_path.iterdir():
|
||||
if item.is_dir() and item.name.startswith("Profile "):
|
||||
if (item / "History").exists():
|
||||
profiles.append(item)
|
||||
|
||||
return profiles
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load browser history and convert to text chunks."""
|
||||
# Determine Chrome profiles
|
||||
if args.chrome_profile and not args.auto_find_profiles:
|
||||
profile_dirs = [Path(args.chrome_profile)]
|
||||
else:
|
||||
print("Auto-detecting Chrome profiles...")
|
||||
profile_dirs = self._find_chrome_profiles()
|
||||
|
||||
# If specific profile given, filter to just that one
|
||||
if args.chrome_profile:
|
||||
profile_path = Path(args.chrome_profile)
|
||||
profile_dirs = [p for p in profile_dirs if p == profile_path]
|
||||
|
||||
if not profile_dirs:
|
||||
print("No Chrome profiles found!")
|
||||
print("Please specify --chrome-profile manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
|
||||
# Create reader
|
||||
reader = ChromeHistoryReader()
|
||||
|
||||
# Process each profile
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, profile_dir in enumerate(profile_dirs):
|
||||
print(f"\nProcessing profile {i + 1}/{len(profile_dirs)}: {profile_dir.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per profile
|
||||
max_per_profile = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_profile = remaining
|
||||
|
||||
# Load history
|
||||
documents = reader.load_data(
|
||||
chrome_profile_path=str(profile_dir),
|
||||
max_count=max_per_profile,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} history entries from this profile")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {profile_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No browser history found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal history entries processed: {len(all_documents)}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for browser history RAG
|
||||
print("\n🌐 Browser History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What websites did I visit about machine learning?'")
|
||||
print("- 'Find my search history about programming'")
|
||||
print("- 'What YouTube videos did I watch recently?'")
|
||||
print("- 'Show me websites about travel planning'")
|
||||
print("\nNote: Make sure Chrome is closed before running\n")
|
||||
|
||||
rag = BrowserRAG()
|
||||
asyncio.run(rag.run())
|
||||
0
apps/chatgpt_data/__init__.py
Normal file
0
apps/chatgpt_data/__init__.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
413
apps/chatgpt_data/chatgpt_reader.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads and processes ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChatGPTReader(BaseReader):
|
||||
"""
|
||||
ChatGPT export data reader.
|
||||
|
||||
Reads ChatGPT conversation data from exported chat.html files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_html_from_zip(self, zip_path: Path) -> str | None:
|
||||
"""
|
||||
Extract chat.html from ChatGPT export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the ChatGPT export zip file
|
||||
|
||||
Returns:
|
||||
HTML content as string, or None if not found
|
||||
"""
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for chat.html or conversations.html
|
||||
html_files = [
|
||||
f
|
||||
for f in zip_file.namelist()
|
||||
if f.endswith(".html") and ("chat" in f.lower() or "conversation" in f.lower())
|
||||
]
|
||||
|
||||
if not html_files:
|
||||
print(f"No HTML chat file found in {zip_path}")
|
||||
return None
|
||||
|
||||
# Use the first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
|
||||
with zip_file.open(html_file) as f:
|
||||
return f.read().decode("utf-8", errors="ignore")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting HTML from zip {zip_path}: {e}")
|
||||
return None
|
||||
|
||||
def _parse_chatgpt_html(self, html_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse ChatGPT HTML export to extract conversations.
|
||||
|
||||
Args:
|
||||
html_content: HTML content from ChatGPT export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
conversations = []
|
||||
|
||||
# Try different possible structures for ChatGPT exports
|
||||
# Structure 1: Look for conversation containers
|
||||
conversation_containers = soup.find_all(
|
||||
["div", "section"], class_=re.compile(r"conversation|chat", re.I)
|
||||
)
|
||||
|
||||
if not conversation_containers:
|
||||
# Structure 2: Look for message containers directly
|
||||
conversation_containers = [soup] # Use the entire document as one conversation
|
||||
|
||||
for container in conversation_containers:
|
||||
conversation = self._extract_conversation_from_container(container)
|
||||
if conversation and conversation.get("messages"):
|
||||
conversations.append(conversation)
|
||||
|
||||
# If no structured conversations found, try to extract all text as one conversation
|
||||
if not conversations:
|
||||
all_text = soup.get_text(separator="\n", strip=True)
|
||||
if all_text:
|
||||
conversations.append(
|
||||
{
|
||||
"title": "ChatGPT Conversation",
|
||||
"messages": [{"role": "mixed", "content": all_text, "timestamp": None}],
|
||||
"timestamp": None,
|
||||
}
|
||||
)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_container(self, container) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a container element.
|
||||
|
||||
Args:
|
||||
container: BeautifulSoup element containing conversation
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Look for message elements with various possible structures
|
||||
message_selectors = ['[class*="message"]', '[class*="chat"]', "[data-message]", "p", "div"]
|
||||
|
||||
for selector in message_selectors:
|
||||
message_elements = container.select(selector)
|
||||
if message_elements:
|
||||
break
|
||||
else:
|
||||
message_elements = []
|
||||
|
||||
# If no structured messages found, treat the entire container as one message
|
||||
if not message_elements:
|
||||
text_content = container.get_text(separator="\n", strip=True)
|
||||
if text_content:
|
||||
messages.append({"role": "mixed", "content": text_content, "timestamp": None})
|
||||
else:
|
||||
for element in message_elements:
|
||||
message = self._extract_message_from_element(element)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Try to extract conversation title
|
||||
title_element = container.find(["h1", "h2", "h3", "title"])
|
||||
title = title_element.get_text(strip=True) if title_element else "ChatGPT Conversation"
|
||||
|
||||
# Try to extract timestamp from various possible locations
|
||||
timestamp = self._extract_timestamp_from_container(container)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_element(self, element) -> dict | None:
|
||||
"""
|
||||
Extract message data from an element.
|
||||
|
||||
Args:
|
||||
element: BeautifulSoup element containing message
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
text_content = element.get_text(separator=" ", strip=True)
|
||||
|
||||
# Skip empty or very short messages
|
||||
if not text_content or len(text_content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Try to determine role (user/assistant) from class names or content
|
||||
role = "mixed" # Default role
|
||||
|
||||
class_names = " ".join(element.get("class", [])).lower()
|
||||
if "user" in class_names or "human" in class_names:
|
||||
role = "user"
|
||||
elif "assistant" in class_names or "ai" in class_names or "gpt" in class_names:
|
||||
role = "assistant"
|
||||
elif text_content.lower().startswith(("you:", "user:", "me:")):
|
||||
role = "user"
|
||||
text_content = re.sub(r"^(you|user|me):\s*", "", text_content, flags=re.IGNORECASE)
|
||||
elif text_content.lower().startswith(("chatgpt:", "assistant:", "ai:")):
|
||||
role = "assistant"
|
||||
text_content = re.sub(
|
||||
r"^(chatgpt|assistant|ai):\s*", "", text_content, flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# Try to extract timestamp
|
||||
timestamp = self._extract_timestamp_from_element(element)
|
||||
|
||||
return {"role": role, "content": text_content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_element(self, element) -> str | None:
|
||||
"""Extract timestamp from element."""
|
||||
# Look for timestamp in various attributes and child elements
|
||||
timestamp_attrs = ["data-timestamp", "timestamp", "datetime"]
|
||||
for attr in timestamp_attrs:
|
||||
if element.get(attr):
|
||||
return element.get(attr)
|
||||
|
||||
# Look for time elements
|
||||
time_element = element.find("time")
|
||||
if time_element:
|
||||
return time_element.get("datetime") or time_element.get_text(strip=True)
|
||||
|
||||
# Look for date-like text patterns
|
||||
text = element.get_text()
|
||||
date_patterns = [r"\d{4}-\d{2}-\d{2}", r"\d{1,2}/\d{1,2}/\d{4}", r"\w+ \d{1,2}, \d{4}"]
|
||||
|
||||
for pattern in date_patterns:
|
||||
match = re.search(pattern, text)
|
||||
if match:
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_container(self, container) -> str | None:
|
||||
"""Extract timestamp from conversation container."""
|
||||
return self._extract_timestamp_from_element(container)
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "ChatGPT Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[ChatGPT]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load ChatGPT export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing ChatGPT export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
chatgpt_export_path (str): Specific path to ChatGPT export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
chatgpt_export_path = load_kwargs.get("chatgpt_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not chatgpt_export_path:
|
||||
print("No ChatGPT export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(chatgpt_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
html_content = None
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract HTML from zip file
|
||||
html_content = self._extract_html_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".html":
|
||||
# Read HTML file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for HTML files in directory
|
||||
html_files = list(export_path.glob("*.html"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if html_files:
|
||||
# Use first HTML file found
|
||||
html_file = html_files[0]
|
||||
print(f"Found HTML file: {html_file}")
|
||||
try:
|
||||
with open(html_file, encoding="utf-8", errors="ignore") as f:
|
||||
html_content = f.read()
|
||||
except Exception as e:
|
||||
print(f"Error reading HTML file {html_file}: {e}")
|
||||
return docs
|
||||
|
||||
elif zip_files:
|
||||
# Use first zip file found
|
||||
zip_file = zip_files[0]
|
||||
print(f"Found zip file: {zip_file}")
|
||||
html_content = self._extract_html_from_zip(zip_file)
|
||||
|
||||
else:
|
||||
print(f"No HTML or zip files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not html_content:
|
||||
print("No HTML content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from HTML
|
||||
print("Parsing ChatGPT conversations from HTML...")
|
||||
conversations = self._parse_chatgpt_html(html_content)
|
||||
|
||||
if not conversations:
|
||||
print("No conversations found in HTML content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "ChatGPT Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "ChatGPT Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "ChatGPT Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from ChatGPT export")
|
||||
return docs
|
||||
187
apps/chatgpt_rag.py
Normal file
187
apps/chatgpt_rag.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
ChatGPT RAG example using the unified interface.
|
||||
Supports ChatGPT export data from chat.html files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .chatgpt_data.chatgpt_reader import ChatGPTReader
|
||||
|
||||
|
||||
class ChatGPTRAG(BaseRAGExample):
|
||||
"""RAG example for ChatGPT conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="ChatGPT",
|
||||
description="Process and query ChatGPT conversation exports with LEANN",
|
||||
default_index_name="chatgpt_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add ChatGPT-specific arguments."""
|
||||
chatgpt_group = parser.add_argument_group("ChatGPT Parameters")
|
||||
chatgpt_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./chatgpt_export",
|
||||
help="Path to ChatGPT export file (.zip or .html) or directory containing exports (default: ./chatgpt_export)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
chatgpt_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_chatgpt_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find ChatGPT export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to ChatGPT export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".html"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and html files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.html"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load ChatGPT export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"ChatGPT export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your ChatGPT data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your ChatGPT data:")
|
||||
print("1. Sign in to ChatGPT")
|
||||
print("2. Click on your profile icon → Settings → Data Controls")
|
||||
print("3. Click 'Export' under Export Data")
|
||||
print("4. Download the zip file from the email link")
|
||||
print("5. Extract or place the file/directory at the specified path")
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_chatgpt_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No ChatGPT export files (.zip or .html) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} ChatGPT export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ChatGPTReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
chatgpt_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid ChatGPT export")
|
||||
print("- Check that the HTML file contains conversation data")
|
||||
print("- Try extracting the zip file and pointing to the HTML file directly")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for ChatGPT RAG
|
||||
print("\n🤖 ChatGPT RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about travel planning'")
|
||||
print("- 'What advice did ChatGPT give me about career development?'")
|
||||
print("- 'Search for conversations about cooking recipes'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your ChatGPT data from Settings → Data Controls → Export")
|
||||
print("2. Place the downloaded zip file or extracted HTML in ./chatgpt_export/")
|
||||
print("3. Run this script to build your personal ChatGPT knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ChatGPTRAG()
|
||||
asyncio.run(rag.run())
|
||||
47
apps/chunking/__init__.py
Normal file
47
apps/chunking/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Unified chunking utilities facade.
|
||||
|
||||
This module re-exports the packaged utilities from `leann.chunking_utils` so
|
||||
that both repo apps (importing `chunking`) and installed wheels share one
|
||||
single implementation. When running from the repo without installation, it
|
||||
adds the `packages/leann-core/src` directory to `sys.path` as a fallback.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
except Exception: # pragma: no cover - best-effort fallback for dev environment
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
leann_src = repo_root / "packages" / "leann-core" / "src"
|
||||
if leann_src.exists():
|
||||
sys.path.insert(0, str(leann_src))
|
||||
from leann.chunking_utils import (
|
||||
CODE_EXTENSIONS,
|
||||
_traditional_chunks_as_dicts,
|
||||
create_ast_chunks,
|
||||
create_text_chunks,
|
||||
create_traditional_chunks,
|
||||
detect_code_files,
|
||||
get_language_from_extension,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
__all__ = [
|
||||
"CODE_EXTENSIONS",
|
||||
"_traditional_chunks_as_dicts",
|
||||
"create_ast_chunks",
|
||||
"create_text_chunks",
|
||||
"create_traditional_chunks",
|
||||
"detect_code_files",
|
||||
"get_language_from_extension",
|
||||
]
|
||||
0
apps/claude_data/__init__.py
Normal file
0
apps/claude_data/__init__.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
420
apps/claude_data/claude_reader.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads and processes Claude conversation data from exported JSON files.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ClaudeReader(BaseReader):
|
||||
"""
|
||||
Claude export data reader.
|
||||
|
||||
Reads Claude conversation data from exported JSON files or zip archives.
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _extract_json_from_zip(self, zip_path: Path) -> list[str]:
|
||||
"""
|
||||
Extract JSON files from Claude export zip file.
|
||||
|
||||
Args:
|
||||
zip_path: Path to the Claude export zip file
|
||||
|
||||
Returns:
|
||||
List of JSON content strings, or empty list if not found
|
||||
"""
|
||||
json_contents = []
|
||||
try:
|
||||
with ZipFile(zip_path, "r") as zip_file:
|
||||
# Look for JSON files
|
||||
json_files = [f for f in zip_file.namelist() if f.endswith(".json")]
|
||||
|
||||
if not json_files:
|
||||
print(f"No JSON files found in {zip_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(json_files)} JSON files in archive")
|
||||
|
||||
for json_file in json_files:
|
||||
with zip_file.open(json_file) as f:
|
||||
content = f.read().decode("utf-8", errors="ignore")
|
||||
json_contents.append(content)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting JSON from zip {zip_path}: {e}")
|
||||
|
||||
return json_contents
|
||||
|
||||
def _parse_claude_json(self, json_content: str) -> list[dict]:
|
||||
"""
|
||||
Parse Claude JSON export to extract conversations.
|
||||
|
||||
Args:
|
||||
json_content: JSON content from Claude export
|
||||
|
||||
Returns:
|
||||
List of conversation dictionaries
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return []
|
||||
|
||||
conversations = []
|
||||
|
||||
# Handle different possible JSON structures
|
||||
if isinstance(data, list):
|
||||
# If data is a list of conversations
|
||||
for item in data:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif isinstance(data, dict):
|
||||
# Check for common structures
|
||||
if "conversations" in data:
|
||||
# Structure: {"conversations": [...]}
|
||||
for item in data["conversations"]:
|
||||
conversation = self._extract_conversation_from_json(item)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
elif "messages" in data:
|
||||
# Single conversation with messages
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
else:
|
||||
# Try to treat the whole object as a conversation
|
||||
conversation = self._extract_conversation_from_json(data)
|
||||
if conversation:
|
||||
conversations.append(conversation)
|
||||
|
||||
return conversations
|
||||
|
||||
def _extract_conversation_from_json(self, conv_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract conversation data from a JSON object.
|
||||
|
||||
Args:
|
||||
conv_data: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Dictionary with conversation data or None
|
||||
"""
|
||||
if not isinstance(conv_data, dict):
|
||||
return None
|
||||
|
||||
messages = []
|
||||
|
||||
# Look for messages in various possible structures
|
||||
message_sources = []
|
||||
if "messages" in conv_data:
|
||||
message_sources = conv_data["messages"]
|
||||
elif "chat" in conv_data:
|
||||
message_sources = conv_data["chat"]
|
||||
elif "conversation" in conv_data:
|
||||
message_sources = conv_data["conversation"]
|
||||
else:
|
||||
# If no clear message structure, try to extract from the object itself
|
||||
if "content" in conv_data and "role" in conv_data:
|
||||
message_sources = [conv_data]
|
||||
|
||||
for msg_data in message_sources:
|
||||
message = self._extract_message_from_json(msg_data)
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Extract conversation metadata
|
||||
title = self._extract_title_from_conversation(conv_data, messages)
|
||||
timestamp = self._extract_timestamp_from_conversation(conv_data)
|
||||
|
||||
return {"title": title, "messages": messages, "timestamp": timestamp}
|
||||
|
||||
def _extract_message_from_json(self, msg_data: dict) -> dict | None:
|
||||
"""
|
||||
Extract message data from a JSON message object.
|
||||
|
||||
Args:
|
||||
msg_data: Dictionary containing message data
|
||||
|
||||
Returns:
|
||||
Dictionary with message data or None
|
||||
"""
|
||||
if not isinstance(msg_data, dict):
|
||||
return None
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = ""
|
||||
content_fields = ["content", "text", "message", "body"]
|
||||
for field in content_fields:
|
||||
if msg_data.get(field):
|
||||
content = str(msg_data[field])
|
||||
break
|
||||
|
||||
if not content or len(content.strip()) < 3:
|
||||
return None
|
||||
|
||||
# Extract role (user/assistant/human/ai/claude)
|
||||
role = "mixed" # Default role
|
||||
role_fields = ["role", "sender", "from", "author", "type"]
|
||||
for field in role_fields:
|
||||
if msg_data.get(field):
|
||||
role_value = str(msg_data[field]).lower()
|
||||
if role_value in ["user", "human", "person"]:
|
||||
role = "user"
|
||||
elif role_value in ["assistant", "ai", "claude", "bot"]:
|
||||
role = "assistant"
|
||||
break
|
||||
|
||||
# Extract timestamp
|
||||
timestamp = self._extract_timestamp_from_message(msg_data)
|
||||
|
||||
return {"role": role, "content": content, "timestamp": timestamp}
|
||||
|
||||
def _extract_timestamp_from_message(self, msg_data: dict) -> str | None:
|
||||
"""Extract timestamp from message data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "time"]
|
||||
for field in timestamp_fields:
|
||||
if msg_data.get(field):
|
||||
return str(msg_data[field])
|
||||
return None
|
||||
|
||||
def _extract_timestamp_from_conversation(self, conv_data: dict) -> str | None:
|
||||
"""Extract timestamp from conversation data."""
|
||||
timestamp_fields = ["timestamp", "created_at", "date", "updated_at", "last_updated"]
|
||||
for field in timestamp_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
return None
|
||||
|
||||
def _extract_title_from_conversation(self, conv_data: dict, messages: list) -> str:
|
||||
"""Extract or generate title for conversation."""
|
||||
# Try to find explicit title
|
||||
title_fields = ["title", "name", "subject", "topic"]
|
||||
for field in title_fields:
|
||||
if conv_data.get(field):
|
||||
return str(conv_data[field])
|
||||
|
||||
# Generate title from first user message
|
||||
for message in messages:
|
||||
if message.get("role") == "user":
|
||||
content = message.get("content", "")
|
||||
if content:
|
||||
# Use first 50 characters as title
|
||||
title = content[:50].strip()
|
||||
if len(content) > 50:
|
||||
title += "..."
|
||||
return title
|
||||
|
||||
return "Claude Conversation"
|
||||
|
||||
def _create_concatenated_content(self, conversation: dict) -> str:
|
||||
"""
|
||||
Create concatenated content from conversation messages.
|
||||
|
||||
Args:
|
||||
conversation: Dictionary containing conversation data
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
title = conversation.get("title", "Claude Conversation")
|
||||
messages = conversation.get("messages", [])
|
||||
timestamp = conversation.get("timestamp", "Unknown")
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if role == "user":
|
||||
prefix = "[You]"
|
||||
elif role == "assistant":
|
||||
prefix = "[Claude]"
|
||||
else:
|
||||
prefix = "[Message]"
|
||||
|
||||
# Add timestamp if available
|
||||
if msg_timestamp:
|
||||
prefix += f" ({msg_timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {content}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""Conversation: {title}
|
||||
Date: {timestamp}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Claude export data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing Claude export files or path to specific file
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum number of conversations to process
|
||||
claude_export_path (str): Specific path to Claude export file/directory
|
||||
include_metadata (bool): Whether to include metadata in documents
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", -1)
|
||||
claude_export_path = load_kwargs.get("claude_export_path", input_dir)
|
||||
include_metadata = load_kwargs.get("include_metadata", True)
|
||||
|
||||
if not claude_export_path:
|
||||
print("No Claude export path provided")
|
||||
return docs
|
||||
|
||||
export_path = Path(claude_export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
return docs
|
||||
|
||||
json_contents = []
|
||||
|
||||
# Handle different input types
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() == ".zip":
|
||||
# Extract JSON from zip file
|
||||
json_contents = self._extract_json_from_zip(export_path)
|
||||
elif export_path.suffix.lower() == ".json":
|
||||
# Read JSON file directly
|
||||
try:
|
||||
with open(export_path, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {export_path}: {e}")
|
||||
return docs
|
||||
else:
|
||||
print(f"Unsupported file type: {export_path.suffix}")
|
||||
return docs
|
||||
|
||||
elif export_path.is_dir():
|
||||
# Look for JSON files in directory
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
zip_files = list(export_path.glob("*.zip"))
|
||||
|
||||
if json_files:
|
||||
print(f"Found {len(json_files)} JSON files in directory")
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, encoding="utf-8", errors="ignore") as f:
|
||||
json_contents.append(f.read())
|
||||
except Exception as e:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
continue
|
||||
|
||||
if zip_files:
|
||||
print(f"Found {len(zip_files)} ZIP files in directory")
|
||||
for zip_file in zip_files:
|
||||
zip_contents = self._extract_json_from_zip(zip_file)
|
||||
json_contents.extend(zip_contents)
|
||||
|
||||
if not json_files and not zip_files:
|
||||
print(f"No JSON or ZIP files found in {export_path}")
|
||||
return docs
|
||||
|
||||
if not json_contents:
|
||||
print("No JSON content found to process")
|
||||
return docs
|
||||
|
||||
# Parse conversations from JSON content
|
||||
print("Parsing Claude conversations from JSON...")
|
||||
all_conversations = []
|
||||
for json_content in json_contents:
|
||||
conversations = self._parse_claude_json(json_content)
|
||||
all_conversations.extend(conversations)
|
||||
|
||||
if not all_conversations:
|
||||
print("No conversations found in JSON content")
|
||||
return docs
|
||||
|
||||
print(f"Found {len(all_conversations)} conversations")
|
||||
|
||||
# Process conversations into documents
|
||||
count = 0
|
||||
for conversation in all_conversations:
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Create one document per conversation with concatenated messages
|
||||
doc_content = self._create_concatenated_content(conversation)
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"title": conversation.get("title", "Claude Conversation"),
|
||||
"timestamp": conversation.get("timestamp", "Unknown"),
|
||||
"message_count": len(conversation.get("messages", [])),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
else:
|
||||
# Create separate documents for each message
|
||||
for message in conversation.get("messages", []):
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
role = message.get("role", "mixed")
|
||||
content = message.get("content", "")
|
||||
msg_timestamp = message.get("timestamp", "")
|
||||
|
||||
if not content.strip():
|
||||
continue
|
||||
|
||||
# Create document content with context
|
||||
doc_content = f"""Conversation: {conversation.get("title", "Claude Conversation")}
|
||||
Role: {role}
|
||||
Timestamp: {msg_timestamp or conversation.get("timestamp", "Unknown")}
|
||||
Message: {content}
|
||||
"""
|
||||
|
||||
metadata = {}
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"conversation_title": conversation.get("title", "Claude Conversation"),
|
||||
"role": role,
|
||||
"timestamp": msg_timestamp or conversation.get("timestamp", "Unknown"),
|
||||
"source": "Claude Export",
|
||||
}
|
||||
|
||||
doc = Document(text=doc_content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(docs)} documents from Claude export")
|
||||
return docs
|
||||
190
apps/claude_rag.py
Normal file
190
apps/claude_rag.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Claude RAG example using the unified interface.
|
||||
Supports Claude export data from JSON files.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .claude_data.claude_reader import ClaudeReader
|
||||
|
||||
|
||||
class ClaudeRAG(BaseRAGExample):
|
||||
"""RAG example for Claude conversation data."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all conversations by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Claude",
|
||||
description="Process and query Claude conversation exports with LEANN",
|
||||
default_index_name="claude_conversations_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add Claude-specific arguments."""
|
||||
claude_group = parser.add_argument_group("Claude Parameters")
|
||||
claude_group.add_argument(
|
||||
"--export-path",
|
||||
type=str,
|
||||
default="./claude_export",
|
||||
help="Path to Claude export file (.json or .zip) or directory containing exports (default: ./claude_export)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--separate-messages",
|
||||
action="store_true",
|
||||
help="Process each message as a separate document (overrides --concatenate-conversations)",
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-size", type=int, default=512, help="Text chunk size (default: 512)"
|
||||
)
|
||||
claude_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
|
||||
def _find_claude_exports(self, export_path: Path) -> list[Path]:
|
||||
"""
|
||||
Find Claude export files in the given path.
|
||||
|
||||
Args:
|
||||
export_path: Path to search for exports
|
||||
|
||||
Returns:
|
||||
List of paths to Claude export files
|
||||
"""
|
||||
export_files = []
|
||||
|
||||
if export_path.is_file():
|
||||
if export_path.suffix.lower() in [".zip", ".json"]:
|
||||
export_files.append(export_path)
|
||||
elif export_path.is_dir():
|
||||
# Look for zip and json files
|
||||
export_files.extend(export_path.glob("*.zip"))
|
||||
export_files.extend(export_path.glob("*.json"))
|
||||
|
||||
return export_files
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Claude export data and convert to text chunks."""
|
||||
export_path = Path(args.export_path)
|
||||
|
||||
if not export_path.exists():
|
||||
print(f"Claude export path not found: {export_path}")
|
||||
print(
|
||||
"Please ensure you have exported your Claude data and placed it in the correct location."
|
||||
)
|
||||
print("\nTo export your Claude data:")
|
||||
print("1. Open Claude in your browser")
|
||||
print("2. Look for export/download options in settings or conversation menu")
|
||||
print("3. Download the conversation data (usually in JSON format)")
|
||||
print("4. Place the file/directory at the specified path")
|
||||
print(
|
||||
"\nNote: Claude export methods may vary. Check Claude's help documentation for current instructions."
|
||||
)
|
||||
return []
|
||||
|
||||
# Find export files
|
||||
export_files = self._find_claude_exports(export_path)
|
||||
|
||||
if not export_files:
|
||||
print(f"No Claude export files (.json or .zip) found in: {export_path}")
|
||||
return []
|
||||
|
||||
print(f"Found {len(export_files)} Claude export files")
|
||||
|
||||
# Create reader with appropriate settings
|
||||
concatenate = args.concatenate_conversations and not args.separate_messages
|
||||
reader = ClaudeReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Process each export file
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_file in enumerate(export_files):
|
||||
print(f"\nProcessing export file {i + 1}/{len(export_files)}: {export_file.name}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per file
|
||||
max_per_file = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_file = remaining
|
||||
|
||||
# Load conversations
|
||||
documents = reader.load_data(
|
||||
claude_export_path=str(export_file),
|
||||
max_count=max_per_file,
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} conversations from this file")
|
||||
else:
|
||||
print(f"No conversations loaded from {export_file}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_file}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No conversations found to process!")
|
||||
print("\nTroubleshooting:")
|
||||
print("- Ensure the export file is a valid Claude export")
|
||||
print("- Check that the JSON file contains conversation data")
|
||||
print("- Try using a different export format or method")
|
||||
print("- Check Claude's documentation for current export procedures")
|
||||
return []
|
||||
|
||||
print(f"\nTotal conversations processed: {len(all_documents)}")
|
||||
print("Now starting to split into text chunks... this may take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} conversations")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for Claude RAG
|
||||
print("\n🤖 Claude RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did I ask Claude about Python programming?'")
|
||||
print("- 'Show me conversations about machine learning'")
|
||||
print("- 'Find discussions about code optimization'")
|
||||
print("- 'What advice did Claude give me about software design?'")
|
||||
print("- 'Search for conversations about debugging techniques'")
|
||||
print("\nTo get started:")
|
||||
print("1. Export your Claude conversation data")
|
||||
print("2. Place the JSON/ZIP file in ./claude_export/")
|
||||
print("3. Run this script to build your personal Claude knowledge base!")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = ClaudeRAG()
|
||||
asyncio.run(rag.run())
|
||||
207
apps/code_rag.py
Normal file
207
apps/code_rag.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Code RAG example using AST-aware chunking for optimal code understanding.
|
||||
Specialized for code repositories with automatic language detection and
|
||||
optimized chunking parameters.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import CODE_EXTENSIONS, create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class CodeRAG(BaseRAGExample):
|
||||
"""Specialized RAG example for code repositories with AST-aware chunking."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Code",
|
||||
description="Process and query code repositories with AST-aware chunking",
|
||||
default_index_name="code_index",
|
||||
)
|
||||
# Override defaults for code-specific usage
|
||||
self.embedding_model_default = "facebook/contriever" # Good for code
|
||||
self.max_items_default = -1 # Process all code files by default
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add code-specific arguments."""
|
||||
code_group = parser.add_argument_group("Code Repository Parameters")
|
||||
|
||||
code_group.add_argument(
|
||||
"--repo-dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Code repository directory to index (default: current directory)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-extensions",
|
||||
nargs="+",
|
||||
default=list(CODE_EXTENSIONS.keys()),
|
||||
help="File extensions to include (default: supported code extensions)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--exclude-dirs",
|
||||
nargs="+",
|
||||
default=[
|
||||
".git",
|
||||
"__pycache__",
|
||||
"node_modules",
|
||||
"venv",
|
||||
".venv",
|
||||
"build",
|
||||
"dist",
|
||||
"target",
|
||||
],
|
||||
help="Directories to exclude from indexing",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--max-file-size",
|
||||
type=int,
|
||||
default=1000000, # 1MB
|
||||
help="Maximum file size in bytes to process (default: 1MB)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--include-comments",
|
||||
action="store_true",
|
||||
help="Include comments in chunking (useful for documentation)",
|
||||
)
|
||||
code_group.add_argument(
|
||||
"--preserve-imports",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Try to preserve import statements in chunks (default: True)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load code files and convert to AST-aware chunks."""
|
||||
print(f"🔍 Scanning code repository: {args.repo_dir}")
|
||||
print(f"📁 Including extensions: {args.include_extensions}")
|
||||
print(f"🚫 Excluding directories: {args.exclude_dirs}")
|
||||
|
||||
# Check if repository directory exists
|
||||
repo_path = Path(args.repo_dir)
|
||||
if not repo_path.exists():
|
||||
raise ValueError(f"Repository directory not found: {args.repo_dir}")
|
||||
|
||||
# Create exclusion filter
|
||||
def file_filter(file_path: str) -> bool:
|
||||
"""Filter out unwanted files and directories."""
|
||||
path = Path(file_path)
|
||||
|
||||
# Check file size
|
||||
try:
|
||||
if path.stat().st_size > args.max_file_size:
|
||||
print(f"⚠️ Skipping large file: {path.name} ({path.stat().st_size} bytes)")
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# Check if in excluded directory
|
||||
for exclude_dir in args.exclude_dirs:
|
||||
if exclude_dir in path.parts:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
try:
|
||||
# Load documents with file filtering
|
||||
documents = SimpleDirectoryReader(
|
||||
args.repo_dir,
|
||||
file_extractor=None,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.include_extensions,
|
||||
exclude_hidden=True,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
# Apply custom filtering
|
||||
filtered_docs = []
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_filter(file_path):
|
||||
filtered_docs.append(doc)
|
||||
|
||||
documents = filtered_docs
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading code files: {e}")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print(
|
||||
f"❌ No code files found in {args.repo_dir} with extensions {args.include_extensions}"
|
||||
)
|
||||
return []
|
||||
|
||||
print(f"✅ Loaded {len(documents)} code files")
|
||||
|
||||
# Show breakdown by language/extension
|
||||
ext_counts = {}
|
||||
for doc in documents:
|
||||
file_path = doc.metadata.get("file_path", "")
|
||||
if file_path:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
ext_counts[ext] = ext_counts.get(ext, 0) + 1
|
||||
|
||||
print("📊 Files by extension:")
|
||||
for ext, count in sorted(ext_counts.items()):
|
||||
print(f" {ext}: {count} files")
|
||||
|
||||
# Use AST-aware chunking by default for code
|
||||
print(
|
||||
f"🧠 Using AST-aware chunking (chunk_size: {args.ast_chunk_size}, overlap: {args.ast_chunk_overlap})"
|
||||
)
|
||||
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=256, # Fallback for non-code files
|
||||
chunk_overlap=64,
|
||||
use_ast_chunking=True, # Always use AST for code RAG
|
||||
ast_chunk_size=args.ast_chunk_size,
|
||||
ast_chunk_overlap=args.ast_chunk_overlap,
|
||||
code_file_extensions=args.include_extensions,
|
||||
ast_fallback_traditional=True,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"⏳ Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
print(f"✅ Generated {len(all_texts)} code chunks")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for code RAG
|
||||
print("\n💻 Code RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'How does the embedding computation work?'")
|
||||
print("- 'What are the main classes in this codebase?'")
|
||||
print("- 'Show me the search implementation'")
|
||||
print("- 'How is error handling implemented?'")
|
||||
print("- 'What design patterns are used?'")
|
||||
print("- 'Explain the chunking logic'")
|
||||
print("\n🚀 Features:")
|
||||
print("- ✅ AST-aware chunking preserves code structure")
|
||||
print("- ✅ Automatic language detection")
|
||||
print("- ✅ Smart filtering of large files and common excludes")
|
||||
print("- ✅ Optimized for code understanding")
|
||||
print("\nUsage examples:")
|
||||
print(" python -m apps.code_rag --repo-dir ./my_project")
|
||||
print(
|
||||
" python -m apps.code_rag --include-extensions .py .js --query 'How does authentication work?'"
|
||||
)
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = CodeRAG()
|
||||
asyncio.run(rag.run())
|
||||
364
apps/colqwen_rag.py
Normal file
364
apps/colqwen_rag.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ColQwen RAG - Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali
|
||||
|
||||
Usage:
|
||||
python -m apps.colqwen_rag build --pdfs ./my_pdfs/ --index my_index
|
||||
python -m apps.colqwen_rag search my_index "How does attention work?"
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, cast
|
||||
|
||||
# Add LEANN packages to path
|
||||
_repo_root = Path(__file__).resolve().parents[1]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.append(str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.append(str(_leann_hnsw_pkg))
|
||||
|
||||
import torch # noqa: E402
|
||||
from colpali_engine import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor # noqa: E402
|
||||
from colpali_engine.utils.torch_utils import ListDataset # noqa: E402
|
||||
from pdf2image import convert_from_path # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
from torch.utils.data import DataLoader # noqa: E402
|
||||
from tqdm import tqdm # noqa: E402
|
||||
|
||||
# Import the existing multi-vector implementation
|
||||
sys.path.append(str(_repo_root / "apps" / "multimodal" / "vision-based-pdf-multi-vector"))
|
||||
from leann_multi_vector import LeannMultiVector # noqa: E402
|
||||
|
||||
|
||||
class ColQwenRAG:
|
||||
"""Easy-to-use ColQwen RAG system for multimodal PDF retrieval."""
|
||||
|
||||
def __init__(self, model_type: str = "colpali"):
|
||||
"""
|
||||
Initialize ColQwen RAG system.
|
||||
|
||||
Args:
|
||||
model_type: "colqwen2" or "colpali"
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.device = self._get_device()
|
||||
# Use float32 on MPS to avoid memory issues, float16 on CUDA, bfloat16 on CPU
|
||||
if self.device.type == "mps":
|
||||
self.dtype = torch.float32
|
||||
elif self.device.type == "cuda":
|
||||
self.dtype = torch.float16
|
||||
else:
|
||||
self.dtype = torch.bfloat16
|
||||
|
||||
print(f"🚀 Initializing {model_type.upper()} on {self.device} with {self.dtype}")
|
||||
|
||||
# Load model and processor with MPS-optimized settings
|
||||
try:
|
||||
if model_type == "colqwen2":
|
||||
self.model_name = "vidore/colqwen2-v1.0"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColQwen2Processor.from_pretrained(self.model_name)
|
||||
else: # colpali
|
||||
self.model_name = "vidore/colpali-v1.2"
|
||||
if self.device.type == "mps":
|
||||
# For MPS, load on CPU first then move to avoid memory allocation issues
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.model = self.model.to(self.device)
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map=self.device,
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
|
||||
except Exception as e:
|
||||
if "memory" in str(e).lower() or "offload" in str(e).lower():
|
||||
print(f"⚠️ Memory constraint on {self.device}, using CPU with optimizations...")
|
||||
self.device = torch.device("cpu")
|
||||
self.dtype = torch.float32
|
||||
|
||||
if model_type == "colqwen2":
|
||||
self.model = ColQwen2.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
self.model = ColPali.from_pretrained(
|
||||
self.model_name,
|
||||
torch_dtype=self.dtype,
|
||||
device_map="cpu",
|
||||
low_cpu_mem_usage=True,
|
||||
).eval()
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_device(self):
|
||||
"""Auto-select best available device."""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def build_index(self, pdf_paths: list[str], index_name: str, pages_dir: Optional[str] = None):
|
||||
"""
|
||||
Build multimodal index from PDF files.
|
||||
|
||||
Args:
|
||||
pdf_paths: List of PDF file paths
|
||||
index_name: Name for the index
|
||||
pages_dir: Directory to save page images (optional)
|
||||
"""
|
||||
print(f"Building index '{index_name}' from {len(pdf_paths)} PDFs...")
|
||||
|
||||
# Convert PDFs to images
|
||||
all_images = []
|
||||
all_metadata = []
|
||||
|
||||
if pages_dir:
|
||||
os.makedirs(pages_dir, exist_ok=True)
|
||||
|
||||
for pdf_path in tqdm(pdf_paths, desc="Converting PDFs"):
|
||||
try:
|
||||
images = convert_from_path(pdf_path, dpi=150)
|
||||
pdf_name = Path(pdf_path).stem
|
||||
|
||||
for i, image in enumerate(images):
|
||||
# Save image if pages_dir specified
|
||||
if pages_dir:
|
||||
image_path = Path(pages_dir) / f"{pdf_name}_page_{i + 1}.png"
|
||||
image.save(image_path)
|
||||
|
||||
all_images.append(image)
|
||||
all_metadata.append(
|
||||
{
|
||||
"pdf_path": pdf_path,
|
||||
"pdf_name": pdf_name,
|
||||
"page_number": i + 1,
|
||||
"image_path": str(image_path) if pages_dir else None,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing {pdf_path}: {e}")
|
||||
continue
|
||||
|
||||
print(f"📄 Converted {len(all_images)} pages from {len(pdf_paths)} PDFs")
|
||||
print(f"All metadata: {all_metadata}")
|
||||
|
||||
# Generate embeddings
|
||||
print("🧠 Generating embeddings...")
|
||||
embeddings = self._embed_images(all_images)
|
||||
|
||||
# Build LEANN index
|
||||
print("🔍 Building LEANN index...")
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=embeddings.shape[-1],
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Create collection and insert data
|
||||
leann_mv.create_collection()
|
||||
for i, (embedding, metadata) in enumerate(zip(embeddings, all_metadata)):
|
||||
data = {
|
||||
"doc_id": i,
|
||||
"filepath": metadata.get("image_path", ""),
|
||||
"colbert_vecs": embedding.numpy(), # Convert tensor to numpy
|
||||
}
|
||||
leann_mv.insert(data)
|
||||
|
||||
# Build the index
|
||||
leann_mv.create_index()
|
||||
print(f"✅ Index '{index_name}' built successfully!")
|
||||
|
||||
return leann_mv
|
||||
|
||||
def search(self, index_name: str, query: str, top_k: int = 5):
|
||||
"""
|
||||
Search the index with a text query.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to search
|
||||
query: Text query
|
||||
top_k: Number of results to return
|
||||
"""
|
||||
print(f"🔍 Searching '{index_name}' for: '{query}'")
|
||||
|
||||
# Load index
|
||||
leann_mv = LeannMultiVector(
|
||||
index_path=index_name,
|
||||
dim=128, # Will be updated when loading
|
||||
embedding_model_name=self.model_type,
|
||||
)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = self._embed_query(query)
|
||||
|
||||
# Search (returns list of (score, doc_id) tuples)
|
||||
search_results = leann_mv.search(query_embedding.numpy(), topk=top_k)
|
||||
|
||||
# Display results
|
||||
print(f"\n📋 Top {len(search_results)} results:")
|
||||
for i, (score, doc_id) in enumerate(search_results, 1):
|
||||
# Get metadata for this doc_id (we need to load the metadata)
|
||||
print(f"{i}. Score: {score:.3f} | Doc ID: {doc_id}")
|
||||
|
||||
return search_results
|
||||
|
||||
def ask(self, index_name: str, interactive: bool = False):
|
||||
"""
|
||||
Interactive Q&A with the indexed documents.
|
||||
|
||||
Args:
|
||||
index_name: Name of the index to query
|
||||
interactive: Whether to run in interactive mode
|
||||
"""
|
||||
print(f"💬 ColQwen Chat with '{index_name}'")
|
||||
|
||||
if interactive:
|
||||
print("Type 'quit' to exit, 'help' for commands")
|
||||
while True:
|
||||
try:
|
||||
query = input("\n🤔 Your question: ").strip()
|
||||
if query.lower() in ["quit", "exit", "q"]:
|
||||
break
|
||||
elif query.lower() == "help":
|
||||
print("Commands: quit/exit/q (exit), help (this message)")
|
||||
continue
|
||||
elif not query:
|
||||
continue
|
||||
|
||||
self.search(index_name, query, top_k=3)
|
||||
|
||||
# TODO: Add answer generation with Qwen-VL
|
||||
print("\n💡 For detailed answers, we can integrate Qwen-VL here!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
else:
|
||||
query = input("🤔 Your question: ").strip()
|
||||
if query:
|
||||
self.search(index_name, query)
|
||||
|
||||
def _embed_images(self, images: list[Image.Image]) -> torch.Tensor:
|
||||
"""Generate embeddings for a list of images."""
|
||||
dataset = ListDataset(images)
|
||||
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=lambda x: x)
|
||||
|
||||
embeddings = []
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc="Embedding images"):
|
||||
batch_images = cast(list, batch)
|
||||
batch_inputs = self.processor.process_images(batch_images).to(self.device)
|
||||
batch_embeddings = self.model(**batch_inputs)
|
||||
embeddings.append(batch_embeddings.cpu())
|
||||
|
||||
return torch.cat(embeddings, dim=0)
|
||||
|
||||
def _embed_query(self, query: str) -> torch.Tensor:
|
||||
"""Generate embedding for a text query."""
|
||||
with torch.no_grad():
|
||||
query_inputs = self.processor.process_queries([query]).to(self.device)
|
||||
query_embedding = self.model(**query_inputs)
|
||||
return query_embedding.cpu()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ColQwen RAG - Easy multimodal PDF retrieval")
|
||||
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
||||
|
||||
# Build command
|
||||
build_parser = subparsers.add_parser("build", help="Build index from PDFs")
|
||||
build_parser.add_argument("--pdfs", required=True, help="Directory containing PDF files")
|
||||
build_parser.add_argument("--index", required=True, help="Index name")
|
||||
build_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
build_parser.add_argument("--pages-dir", help="Directory to save page images")
|
||||
|
||||
# Search command
|
||||
search_parser = subparsers.add_parser("search", help="Search the index")
|
||||
search_parser.add_argument("index", help="Index name")
|
||||
search_parser.add_argument("query", help="Search query")
|
||||
search_parser.add_argument("--top-k", type=int, default=5, help="Number of results")
|
||||
search_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
# Ask command
|
||||
ask_parser = subparsers.add_parser("ask", help="Interactive Q&A")
|
||||
ask_parser.add_argument("index", help="Index name")
|
||||
ask_parser.add_argument("--interactive", action="store_true", help="Interactive mode")
|
||||
ask_parser.add_argument(
|
||||
"--model", choices=["colqwen2", "colpali"], default="colqwen2", help="Model to use"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Initialize ColQwen RAG
|
||||
if args.command == "build":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
|
||||
# Get PDF files
|
||||
pdf_dir = Path(args.pdfs)
|
||||
if pdf_dir.is_file() and pdf_dir.suffix.lower() == ".pdf":
|
||||
pdf_paths = [str(pdf_dir)]
|
||||
elif pdf_dir.is_dir():
|
||||
pdf_paths = [str(p) for p in pdf_dir.glob("*.pdf")]
|
||||
else:
|
||||
print(f"❌ Invalid PDF path: {args.pdfs}")
|
||||
return
|
||||
|
||||
if not pdf_paths:
|
||||
print(f"❌ No PDF files found in {args.pdfs}")
|
||||
return
|
||||
|
||||
colqwen.build_index(pdf_paths, args.index, args.pages_dir)
|
||||
|
||||
elif args.command == "search":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.search(args.index, args.query, args.top_k)
|
||||
|
||||
elif args.command == "ask":
|
||||
colqwen = ColQwenRAG(args.model)
|
||||
colqwen.ask(args.index, args.interactive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
128
apps/document_rag.py
Normal file
128
apps/document_rag.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Document RAG example using the unified interface.
|
||||
Supports PDF, TXT, MD, and other document formats.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
|
||||
class DocumentRAG(BaseRAGExample):
|
||||
"""RAG example for document processing (PDF, TXT, MD, etc.)."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Document",
|
||||
description="Process and query documents (PDF, TXT, MD, etc.) with LEANN",
|
||||
default_index_name="test_doc_files",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add document-specific arguments."""
|
||||
doc_group = parser.add_argument_group("Document Parameters")
|
||||
doc_group.add_argument(
|
||||
"--data-dir",
|
||||
type=str,
|
||||
default="data",
|
||||
help="Directory containing documents to index (default: data)",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--file-types",
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Filter by file types (e.g., .pdf .txt .md). If not specified, all supported types are processed",
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=128, help="Text chunk overlap (default: 128)"
|
||||
)
|
||||
doc_group.add_argument(
|
||||
"--enable-code-chunking",
|
||||
action="store_true",
|
||||
help="Enable AST-aware chunking for code files in the data directory",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load documents and convert to text chunks."""
|
||||
print(f"Loading documents from: {args.data_dir}")
|
||||
if args.file_types:
|
||||
print(f"Filtering by file types: {args.file_types}")
|
||||
else:
|
||||
print("Processing all supported file types")
|
||||
|
||||
# Check if data directory exists
|
||||
data_path = Path(args.data_dir)
|
||||
if not data_path.exists():
|
||||
raise ValueError(f"Data directory not found: {args.data_dir}")
|
||||
|
||||
# Load documents
|
||||
documents = SimpleDirectoryReader(
|
||||
args.data_dir,
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=args.file_types if args.file_types else None,
|
||||
).load_data(show_progress=True)
|
||||
|
||||
if not documents:
|
||||
print(f"No documents found in {args.data_dir} with extensions {args.file_types}")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Determine chunking strategy
|
||||
use_ast = args.enable_code_chunking or getattr(args, "use_ast_chunking", False)
|
||||
|
||||
if use_ast:
|
||||
print("Using AST-aware chunking for code files")
|
||||
|
||||
# Convert to text chunks with optional AST support
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
use_ast_chunking=use_ast,
|
||||
ast_chunk_size=getattr(args, "ast_chunk_size", 512),
|
||||
ast_chunk_overlap=getattr(args, "ast_chunk_overlap", 64),
|
||||
code_file_extensions=getattr(args, "code_file_extensions", None),
|
||||
ast_fallback_traditional=getattr(args, "ast_fallback_traditional", True),
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0 and len(all_texts) > args.max_items:
|
||||
print(f"Limiting to {args.max_items} chunks (from {len(all_texts)})")
|
||||
all_texts = all_texts[: args.max_items]
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Example queries for document RAG
|
||||
print("\n📄 Document RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What are the main techniques LEANN uses?'")
|
||||
print("- 'What is the technique DLPM?'")
|
||||
print("- 'Who does Elizabeth Bennet marry?'")
|
||||
print(
|
||||
"- 'What is the problem of developing pan gu model Huawei meets? (盘古大模型开发中遇到什么问题?)'"
|
||||
)
|
||||
print("\n🚀 NEW: Code-aware chunking available!")
|
||||
print("- Use --enable-code-chunking to enable AST-aware chunking for code files")
|
||||
print("- Supports Python, Java, C#, TypeScript files")
|
||||
print("- Better semantic understanding of code structure")
|
||||
print("\nOr run without --query for interactive mode\n")
|
||||
|
||||
rag = DocumentRAG()
|
||||
asyncio.run(rag.run())
|
||||
167
apps/email_data/LEANN_email_reader.py
Normal file
167
apps/email_data/LEANN_email_reader.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import email
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
def find_all_messages_directories(root: str | None = None) -> list[Path]:
|
||||
"""
|
||||
Recursively find all 'Messages' directories under the given root.
|
||||
Returns a list of Path objects.
|
||||
"""
|
||||
if root is None:
|
||||
# Auto-detect user's mail path
|
||||
home_dir = os.path.expanduser("~")
|
||||
root = os.path.join(home_dir, "Library", "Mail")
|
||||
|
||||
messages_dirs = []
|
||||
for dirpath, _dirnames, _filenames in os.walk(root):
|
||||
if os.path.basename(dirpath) == "Messages":
|
||||
messages_dirs.append(Path(dirpath))
|
||||
return messages_dirs
|
||||
|
||||
|
||||
class EmlxReader(BaseReader):
|
||||
"""
|
||||
Apple Mail .emlx file reader with embedded metadata.
|
||||
|
||||
Reads individual .emlx files from Apple Mail's storage format.
|
||||
"""
|
||||
|
||||
def __init__(self, include_html: bool = False) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
include_html: Whether to include HTML content in the email body (default: False)
|
||||
"""
|
||||
self.include_html = include_html
|
||||
|
||||
def load_data(self, input_dir: str, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load data from the input directory containing .emlx files.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing .emlx files
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of messages to read.
|
||||
"""
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
count = 0
|
||||
total_files = 0
|
||||
successful_files = 0
|
||||
failed_files = 0
|
||||
|
||||
print(f"Starting to process directory: {input_dir}")
|
||||
|
||||
# Walk through the directory recursively
|
||||
for dirpath, dirnames, filenames in os.walk(input_dir):
|
||||
# Skip hidden directories
|
||||
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
|
||||
|
||||
for filename in filenames:
|
||||
# Check if we've reached the max count (skip if max_count == -1)
|
||||
if max_count > 0 and count >= max_count:
|
||||
break
|
||||
|
||||
if filename.endswith(".emlx"):
|
||||
total_files += 1
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(filepath, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
# .emlx files have a length prefix followed by the email content
|
||||
# The first line contains the length, followed by the email
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1]
|
||||
|
||||
# Parse the email using Python's email module
|
||||
try:
|
||||
msg = email.message_from_string(email_content)
|
||||
|
||||
# Extract email metadata
|
||||
subject = msg.get("Subject", "No Subject")
|
||||
from_addr = msg.get("From", "Unknown")
|
||||
to_addr = msg.get("To", "Unknown")
|
||||
date = msg.get("Date", "Unknown")
|
||||
|
||||
# Extract email body
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if (
|
||||
part.get_content_type() == "text/plain"
|
||||
or part.get_content_type() == "text/html"
|
||||
):
|
||||
if (
|
||||
part.get_content_type() == "text/html"
|
||||
and not self.include_html
|
||||
):
|
||||
continue
|
||||
try:
|
||||
payload = part.get_payload(decode=True)
|
||||
if payload:
|
||||
body += payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding payload: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
payload = msg.get_payload(decode=True)
|
||||
if payload:
|
||||
body = payload.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
print(f"Error decoding single part payload: {e}")
|
||||
body = ""
|
||||
|
||||
# Only create document if we have some content
|
||||
if body.strip() or subject != "No Subject":
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[File]: {filename}
|
||||
[From]: {from_addr}
|
||||
[To]: {to_addr}
|
||||
[Subject]: {subject}
|
||||
[Date]: {date}
|
||||
[EMAIL BODY Start]:
|
||||
{body}
|
||||
"""
|
||||
|
||||
# No separate metadata - everything is in the text
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
successful_files += 1
|
||||
|
||||
# Print first few successful files for debugging
|
||||
if successful_files <= 3:
|
||||
print(
|
||||
f"Successfully loaded: {filename} - Subject: {subject[:50]}..."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error parsing email from {filepath}: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
failed_files += 1
|
||||
if failed_files <= 5: # Only print first few errors
|
||||
print(f"Error reading file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
print("Processing summary:")
|
||||
print(f" Total .emlx files found: {total_files}")
|
||||
print(f" Successfully loaded: {successful_files}")
|
||||
print(f" Failed to load: {failed_files}")
|
||||
print(f" Final documents: {len(docs)}")
|
||||
|
||||
return docs
|
||||
@@ -7,9 +7,9 @@ Contains simple parser for mbox files.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fsspec import AbstractFileSystem
|
||||
from typing import Any
|
||||
|
||||
from fsspec import AbstractFileSystem
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.schema import Document
|
||||
|
||||
@@ -27,11 +27,7 @@ class MboxReader(BaseReader):
|
||||
"""
|
||||
|
||||
DEFAULT_MESSAGE_FORMAT: str = (
|
||||
"Date: {_date}\n"
|
||||
"From: {_from}\n"
|
||||
"To: {_to}\n"
|
||||
"Subject: {_subject}\n"
|
||||
"Content: {_content}"
|
||||
"Date: {_date}\nFrom: {_from}\nTo: {_to}\nSubject: {_subject}\nContent: {_content}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -45,9 +41,7 @@ class MboxReader(BaseReader):
|
||||
try:
|
||||
from bs4 import BeautifulSoup # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`beautifulsoup4` package not found: `pip install beautifulsoup4`"
|
||||
)
|
||||
raise ImportError("`beautifulsoup4` package not found: `pip install beautifulsoup4`")
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_count = max_count
|
||||
@@ -56,9 +50,9 @@ class MboxReader(BaseReader):
|
||||
def load_data(
|
||||
self,
|
||||
file: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse file into string."""
|
||||
# Import required libraries
|
||||
import mailbox
|
||||
@@ -74,7 +68,7 @@ class MboxReader(BaseReader):
|
||||
)
|
||||
|
||||
i = 0
|
||||
results: List[str] = []
|
||||
results: list[str] = []
|
||||
# Load file using mailbox
|
||||
bytes_parser = BytesParser(policy=default).parse
|
||||
mbox = mailbox.mbox(file, factory=bytes_parser) # type: ignore
|
||||
@@ -124,7 +118,7 @@ class MboxReader(BaseReader):
|
||||
class EmlxMboxReader(MboxReader):
|
||||
"""
|
||||
EmlxMboxReader - Modified MboxReader that handles directories of .emlx files.
|
||||
|
||||
|
||||
Extends MboxReader to work with Apple Mail's .emlx format by:
|
||||
1. Reading .emlx files from a directory
|
||||
2. Converting them to mbox format in memory
|
||||
@@ -133,14 +127,15 @@ class EmlxMboxReader(MboxReader):
|
||||
|
||||
def load_data(
|
||||
self,
|
||||
directory: Path,
|
||||
extra_info: Optional[Dict] = None,
|
||||
fs: Optional[AbstractFileSystem] = None,
|
||||
) -> List[Document]:
|
||||
file: Path, # Note: for EmlxMboxReader, this is actually a directory
|
||||
extra_info: dict | None = None,
|
||||
fs: AbstractFileSystem | None = None,
|
||||
) -> list[Document]:
|
||||
"""Parse .emlx files from directory into strings using MboxReader logic."""
|
||||
import tempfile
|
||||
directory = file # Rename for clarity - this is a directory of .emlx files
|
||||
import os
|
||||
|
||||
import tempfile
|
||||
|
||||
if fs:
|
||||
logger.warning(
|
||||
"fs was specified but EmlxMboxReader doesn't support loading "
|
||||
@@ -150,37 +145,37 @@ class EmlxMboxReader(MboxReader):
|
||||
# Find all .emlx files in the directory
|
||||
emlx_files = list(directory.glob("*.emlx"))
|
||||
logger.info(f"Found {len(emlx_files)} .emlx files in {directory}")
|
||||
|
||||
|
||||
if not emlx_files:
|
||||
logger.warning(f"No .emlx files found in {directory}")
|
||||
return []
|
||||
|
||||
# Create a temporary mbox file
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.mbox', delete=False) as temp_mbox:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".mbox", delete=False) as temp_mbox:
|
||||
temp_mbox_path = temp_mbox.name
|
||||
|
||||
|
||||
# Convert .emlx files to mbox format
|
||||
for emlx_file in emlx_files:
|
||||
try:
|
||||
# Read the .emlx file
|
||||
with open(emlx_file, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
with open(emlx_file, encoding="utf-8", errors="ignore") as f:
|
||||
content = f.read()
|
||||
|
||||
|
||||
# .emlx format: first line is length, rest is email content
|
||||
lines = content.split('\n', 1)
|
||||
lines = content.split("\n", 1)
|
||||
if len(lines) >= 2:
|
||||
email_content = lines[1] # Skip the length line
|
||||
|
||||
|
||||
# Write to mbox format (each message starts with "From " and ends with blank line)
|
||||
temp_mbox.write(f"From {emlx_file.name} {email_content}\n\n")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to process {emlx_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Close the temporary file so MboxReader can read it
|
||||
temp_mbox.close()
|
||||
|
||||
|
||||
try:
|
||||
# Use the parent MboxReader's logic to parse the mbox file
|
||||
return super().load_data(Path(temp_mbox_path), extra_info, fs)
|
||||
@@ -188,5 +183,5 @@ class EmlxMboxReader(MboxReader):
|
||||
# Clean up temporary file
|
||||
try:
|
||||
os.unlink(temp_mbox_path)
|
||||
except:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
158
apps/email_rag.py
Normal file
158
apps/email_rag.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Email RAG example using the unified interface.
|
||||
Supports Apple Mail on macOS.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
from chunking import create_text_chunks
|
||||
|
||||
from .email_data.LEANN_email_reader import EmlxReader
|
||||
|
||||
|
||||
class EmailRAG(BaseRAGExample):
|
||||
"""RAG example for Apple Mail processing."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Process all emails by default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="Email",
|
||||
description="Process and query Apple Mail emails with LEANN",
|
||||
default_index_name="mail_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add email-specific arguments."""
|
||||
email_group = parser.add_argument_group("Email Parameters")
|
||||
email_group.add_argument(
|
||||
"--mail-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Apple Mail directory (auto-detected if not specified)",
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--include-html", action="store_true", help="Include HTML content in email processing"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-size", type=int, default=256, help="Text chunk size (default: 256)"
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=25, help="Text chunk overlap (default: 25)"
|
||||
)
|
||||
|
||||
def _find_mail_directories(self) -> list[Path]:
|
||||
"""Auto-detect all Apple Mail directories."""
|
||||
mail_base = Path.home() / "Library" / "Mail"
|
||||
if not mail_base.exists():
|
||||
return []
|
||||
|
||||
# Find all Messages directories
|
||||
messages_dirs = []
|
||||
for item in mail_base.rglob("Messages"):
|
||||
if item.is_dir():
|
||||
messages_dirs.append(item)
|
||||
|
||||
return messages_dirs
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load emails and convert to text chunks."""
|
||||
# Determine mail directories
|
||||
if args.mail_path:
|
||||
messages_dirs = [Path(args.mail_path)]
|
||||
else:
|
||||
print("Auto-detecting Apple Mail directories...")
|
||||
messages_dirs = self._find_mail_directories()
|
||||
|
||||
if not messages_dirs:
|
||||
print("No Apple Mail directories found!")
|
||||
print("Please specify --mail-path manually")
|
||||
return []
|
||||
|
||||
print(f"Found {len(messages_dirs)} mail directories")
|
||||
|
||||
# Create reader
|
||||
reader = EmlxReader(include_html=args.include_html)
|
||||
|
||||
# Process each directory
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, messages_dir in enumerate(messages_dirs):
|
||||
print(f"\nProcessing directory {i + 1}/{len(messages_dirs)}: {messages_dir}")
|
||||
|
||||
try:
|
||||
# Count emlx files
|
||||
emlx_files = list(messages_dir.glob("*.emlx"))
|
||||
print(f"Found {len(emlx_files)} email files")
|
||||
|
||||
# Apply max_items limit per directory
|
||||
max_per_dir = -1 # Default to process all
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_dir = remaining
|
||||
# If args.max_items == -1, max_per_dir stays -1 (process all)
|
||||
|
||||
# Load emails - fix the parameter passing
|
||||
documents = reader.load_data(
|
||||
input_dir=str(messages_dir),
|
||||
max_count=max_per_dir,
|
||||
)
|
||||
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
print(f"Processed {len(documents)} emails from this directory")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {messages_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No emails found to process!")
|
||||
return []
|
||||
|
||||
print(f"\nTotal emails processed: {len(all_documents)}")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks
|
||||
# Email reader uses chunk_overlap=25 as in original
|
||||
all_texts = create_text_chunks(
|
||||
all_documents, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: This example is designed for macOS (Apple Mail)")
|
||||
print(" Windows/Linux support coming soon!\n")
|
||||
|
||||
# Example queries for email RAG
|
||||
print("\n📧 Email RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'What did my boss say about deadlines?'")
|
||||
print("- 'Find emails about travel expenses'")
|
||||
print("- 'Show me emails from last month about the project'")
|
||||
print("- 'What food did I order from DoorDash?'")
|
||||
print("\nNote: You may need to grant Full Disk Access to your terminal\n")
|
||||
|
||||
rag = EmailRAG()
|
||||
asyncio.run(rag.run())
|
||||
@@ -1,3 +1,3 @@
|
||||
from .history import ChromeHistoryReader
|
||||
|
||||
__all__ = ['ChromeHistoryReader']
|
||||
__all__ = ["ChromeHistoryReader"]
|
||||
@@ -1,77 +1,81 @@
|
||||
import sqlite3
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import List, Any
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class ChromeHistoryReader(BaseReader):
|
||||
"""
|
||||
Chrome browser history reader that extracts browsing data from SQLite database.
|
||||
|
||||
|
||||
Reads Chrome history from the default Chrome profile location and creates documents
|
||||
with embedded metadata similar to the email reader structure.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
pass
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load Chrome history data from the default Chrome profile location.
|
||||
|
||||
|
||||
Args:
|
||||
input_dir: Not used for Chrome history (kept for compatibility)
|
||||
**load_kwargs:
|
||||
max_count (int): Maximum amount of history entries to read.
|
||||
chrome_profile_path (str): Custom path to Chrome profile directory.
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
chrome_profile_path = load_kwargs.get('chrome_profile_path', None)
|
||||
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
chrome_profile_path = load_kwargs.get("chrome_profile_path", None)
|
||||
|
||||
# Default Chrome profile path on macOS
|
||||
if chrome_profile_path is None:
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return docs
|
||||
|
||||
|
||||
try:
|
||||
# Connect to the Chrome history database
|
||||
print(f"Connecting to database: {history_db_path}")
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
# Query to get browsing history with metadata (removed created_time column)
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
"""
|
||||
|
||||
|
||||
print(f"Executing query on database: {history_db_path}")
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
print(f"Query returned {len(rows)} rows")
|
||||
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
|
||||
|
||||
last_visit, url, title, visit_count, typed_count, _hidden = row
|
||||
|
||||
# Create document content with metadata embedded in text
|
||||
doc_content = f"""
|
||||
[Title]: {title}
|
||||
@@ -80,38 +84,43 @@ class ChromeHistoryReader(BaseReader):
|
||||
[Visit times]: {visit_count}
|
||||
[Typed times]: {typed_count}
|
||||
"""
|
||||
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={ "title": title[0:150]})
|
||||
doc = Document(text=doc_content, metadata={"title": title[0:150]})
|
||||
# if len(title) > 150:
|
||||
# print(f"Title is too long: {title}")
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
conn.close()
|
||||
print(f"Loaded {len(docs)} Chrome history documents")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading Chrome history: {e}")
|
||||
# add you may need to close your browser to make the database file available
|
||||
# also highlight in red
|
||||
print(
|
||||
"\033[91mYou may need to close your browser to make the database file available\033[0m"
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_chrome_profiles() -> List[Path]:
|
||||
def find_chrome_profiles() -> list[Path]:
|
||||
"""
|
||||
Find all Chrome profile directories.
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to Chrome profile directories
|
||||
"""
|
||||
chrome_base_path = Path(os.path.expanduser("~/Library/Application Support/Google/Chrome"))
|
||||
profile_dirs = []
|
||||
|
||||
|
||||
if not chrome_base_path.exists():
|
||||
print(f"Chrome directory not found at: {chrome_base_path}")
|
||||
return profile_dirs
|
||||
|
||||
|
||||
# Find all profile directories
|
||||
for profile_dir in chrome_base_path.iterdir():
|
||||
if profile_dir.is_dir() and profile_dir.name != "System Profile":
|
||||
@@ -119,53 +128,59 @@ class ChromeHistoryReader(BaseReader):
|
||||
if history_path.exists():
|
||||
profile_dirs.append(profile_dir)
|
||||
print(f"Found Chrome profile: {profile_dir}")
|
||||
|
||||
|
||||
print(f"Found {len(profile_dirs)} Chrome profiles")
|
||||
return profile_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_history_to_file(output_file: str = "chrome_history_export.txt", max_count: int = 1000):
|
||||
def export_history_to_file(
|
||||
output_file: str = "chrome_history_export.txt", max_count: int = 1000
|
||||
):
|
||||
"""
|
||||
Export Chrome history to a text file using the same SQL query format.
|
||||
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
"""
|
||||
chrome_profile_path = os.path.expanduser("~/Library/Application Support/Google/Chrome/Default")
|
||||
chrome_profile_path = os.path.expanduser(
|
||||
"~/Library/Application Support/Google/Chrome/Default"
|
||||
)
|
||||
history_db_path = os.path.join(chrome_profile_path, "History")
|
||||
|
||||
|
||||
if not os.path.exists(history_db_path):
|
||||
print(f"Chrome history database not found at: {history_db_path}")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(history_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
|
||||
query = """
|
||||
SELECT
|
||||
SELECT
|
||||
datetime(last_visit_time/1000000-11644473600,'unixepoch','localtime') as last_visit,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
url,
|
||||
title,
|
||||
visit_count,
|
||||
typed_count,
|
||||
hidden
|
||||
FROM urls
|
||||
FROM urls
|
||||
ORDER BY last_visit_time DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
||||
cursor.execute(query, (max_count,))
|
||||
rows = cursor.fetchall()
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
last_visit, url, title, visit_count, typed_count, hidden = row
|
||||
f.write(f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n")
|
||||
|
||||
f.write(
|
||||
f"{last_visit}\t{url}\t{title}\t{visit_count}\t{typed_count}\t{hidden}\n"
|
||||
)
|
||||
|
||||
conn.close()
|
||||
print(f"Exported {len(rows)} history entries to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
print(f"Error exporting Chrome history: {e}")
|
||||
@@ -2,30 +2,31 @@ import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class WeChatHistoryReader(BaseReader):
|
||||
"""
|
||||
WeChat chat history reader that extracts chat data from exported JSON files.
|
||||
|
||||
|
||||
Reads WeChat chat history from exported JSON files (from wechat-exporter tool)
|
||||
and creates documents with embedded metadata similar to the Chrome history reader structure.
|
||||
|
||||
|
||||
Also includes utilities for automatic WeChat chat history export.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
self.packages_dir = Path(__file__).parent.parent.parent / "packages"
|
||||
self.wechat_exporter_dir = self.packages_dir / "wechat-exporter"
|
||||
self.wechat_decipher_dir = self.packages_dir / "wechat-decipher-macos"
|
||||
|
||||
|
||||
def check_wechat_running(self) -> bool:
|
||||
"""Check if WeChat is currently running."""
|
||||
try:
|
||||
@@ -33,24 +34,30 @@ class WeChatHistoryReader(BaseReader):
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def install_wechattweak(self) -> bool:
|
||||
"""Install WeChatTweak CLI tool."""
|
||||
try:
|
||||
# Create wechat-exporter directory if it doesn't exist
|
||||
self.wechat_exporter_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
wechattweak_path = self.wechat_exporter_dir / "wechattweak-cli"
|
||||
if not wechattweak_path.exists():
|
||||
print("Downloading WeChatTweak CLI...")
|
||||
subprocess.run([
|
||||
"curl", "-L", "-o", str(wechattweak_path),
|
||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli"
|
||||
], check=True)
|
||||
|
||||
subprocess.run(
|
||||
[
|
||||
"curl",
|
||||
"-L",
|
||||
"-o",
|
||||
str(wechattweak_path),
|
||||
"https://github.com/JettChenT/WeChatTweak-CLI/releases/latest/download/wechattweak-cli",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Make executable
|
||||
wechattweak_path.chmod(0o755)
|
||||
|
||||
|
||||
# Install WeChatTweak
|
||||
print("Installing WeChatTweak...")
|
||||
subprocess.run(["sudo", str(wechattweak_path), "install"], check=True)
|
||||
@@ -58,7 +65,7 @@ class WeChatHistoryReader(BaseReader):
|
||||
except Exception as e:
|
||||
print(f"Error installing WeChatTweak: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def restart_wechat(self):
|
||||
"""Restart WeChat to apply WeChatTweak."""
|
||||
try:
|
||||
@@ -69,302 +76,327 @@ class WeChatHistoryReader(BaseReader):
|
||||
time.sleep(5) # Wait for WeChat to start
|
||||
except Exception as e:
|
||||
print(f"Error restarting WeChat: {e}")
|
||||
|
||||
|
||||
def check_api_available(self) -> bool:
|
||||
"""Check if WeChatTweak API is available."""
|
||||
try:
|
||||
result = subprocess.run([
|
||||
"curl", "-s", "http://localhost:48065/wechat/allcontacts"
|
||||
], capture_output=True, text=True, timeout=5)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
result = subprocess.run(
|
||||
["curl", "-s", "http://localhost:48065/wechat/allcontacts"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
def _extract_readable_text(self, content: str) -> str:
|
||||
"""
|
||||
Extract readable text from message content, removing XML and system messages.
|
||||
|
||||
|
||||
Args:
|
||||
content: The raw message content (can be string or dict)
|
||||
|
||||
|
||||
Returns:
|
||||
Cleaned, readable text
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
|
||||
# Handle dictionary content (like quoted messages)
|
||||
if isinstance(content, dict):
|
||||
# Extract text from dictionary structure
|
||||
text_parts = []
|
||||
if 'title' in content:
|
||||
text_parts.append(str(content['title']))
|
||||
if 'quoted' in content:
|
||||
text_parts.append(str(content['quoted']))
|
||||
if 'content' in content:
|
||||
text_parts.append(str(content['content']))
|
||||
if 'text' in content:
|
||||
text_parts.append(str(content['text']))
|
||||
|
||||
if "title" in content:
|
||||
text_parts.append(str(content["title"]))
|
||||
if "quoted" in content:
|
||||
text_parts.append(str(content["quoted"]))
|
||||
if "content" in content:
|
||||
text_parts.append(str(content["content"]))
|
||||
if "text" in content:
|
||||
text_parts.append(str(content["text"]))
|
||||
|
||||
if text_parts:
|
||||
return " | ".join(text_parts)
|
||||
else:
|
||||
# If we can't extract meaningful text from dict, return empty
|
||||
return ""
|
||||
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return ""
|
||||
|
||||
|
||||
# Remove common prefixes like "wxid_xxx:\n"
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||
|
||||
# If it's just XML or system message, return empty
|
||||
if clean_content.strip().startswith('<') or 'recalled a message' in clean_content:
|
||||
if clean_content.strip().startswith("<") or "recalled a message" in clean_content:
|
||||
return ""
|
||||
|
||||
|
||||
return clean_content.strip()
|
||||
|
||||
|
||||
def _is_text_message(self, content: str) -> bool:
|
||||
"""
|
||||
Check if a message contains readable text content.
|
||||
|
||||
|
||||
Args:
|
||||
content: The message content (can be string or dict)
|
||||
|
||||
|
||||
Returns:
|
||||
True if the message contains readable text, False otherwise
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
|
||||
# Handle dictionary content
|
||||
if isinstance(content, dict):
|
||||
# Check if dict has any readable text fields
|
||||
text_fields = ['title', 'quoted', 'content', 'text']
|
||||
text_fields = ["title", "quoted", "content", "text"]
|
||||
for field in text_fields:
|
||||
if field in content and content[field]:
|
||||
if content.get(field):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Handle string content
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
|
||||
# Skip image messages (contain XML with img tags)
|
||||
if '<img' in content and 'cdnurl' in content:
|
||||
if "<img" in content and "cdnurl" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip emoji messages (contain emoji XML tags)
|
||||
if '<emoji' in content and 'productid' in content:
|
||||
if "<emoji" in content and "productid" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip voice messages
|
||||
if '<voice' in content:
|
||||
if "<voice" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip video messages
|
||||
if '<video' in content:
|
||||
if "<video" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip file messages
|
||||
if '<appmsg' in content and 'appid' in content:
|
||||
if "<appmsg" in content and "appid" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Skip system messages (like "recalled a message")
|
||||
if 'recalled a message' in content:
|
||||
if "recalled a message" in content:
|
||||
return False
|
||||
|
||||
|
||||
# Check if there's actual readable text (not just XML or system messages)
|
||||
# Remove common prefixes like "wxid_xxx:\n" and check for actual content
|
||||
clean_content = re.sub(r'^wxid_[^:]+:\s*', '', content)
|
||||
clean_content = re.sub(r'^[^:]+:\s*', '', clean_content)
|
||||
|
||||
clean_content = re.sub(r"^wxid_[^:]+:\s*", "", content)
|
||||
clean_content = re.sub(r"^[^:]+:\s*", "", clean_content)
|
||||
|
||||
# If after cleaning we have meaningful text, consider it readable
|
||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith('<'):
|
||||
if len(clean_content.strip()) > 0 and not clean_content.strip().startswith("<"):
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
def _concatenate_messages(self, messages: List[Dict], max_length: int = 128,
|
||||
time_window_minutes: int = 30, overlap_messages: int = 0) -> List[Dict]:
|
||||
|
||||
def _concatenate_messages(
|
||||
self,
|
||||
messages: list[dict],
|
||||
max_length: int = 128,
|
||||
time_window_minutes: int = 30,
|
||||
overlap_messages: int = 0,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Concatenate messages based on length and time rules.
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
max_length: Maximum length for concatenated message groups. Use -1 to disable length constraint.
|
||||
time_window_minutes: Time window in minutes to group messages together. Use -1 to disable time constraint.
|
||||
overlap_messages: Number of messages to overlap between consecutive groups
|
||||
|
||||
|
||||
Returns:
|
||||
List of concatenated message groups
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
|
||||
concatenated_groups = []
|
||||
current_group = []
|
||||
current_length = 0
|
||||
last_timestamp = None
|
||||
|
||||
|
||||
for message in messages:
|
||||
# Extract message info
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
message.get("fromUser", "")
|
||||
message.get("toUser", "")
|
||||
message.get("isSentFromSelf", False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not readable_text.strip():
|
||||
continue
|
||||
|
||||
|
||||
# Check time window constraint (only if time_window_minutes != -1)
|
||||
if time_window_minutes != -1 and last_timestamp is not None and create_time > 0:
|
||||
time_diff_minutes = (create_time - last_timestamp) / 60
|
||||
if time_diff_minutes > time_window_minutes:
|
||||
# Time gap too large, start new group
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
current_length = sum(
|
||||
len(
|
||||
self._extract_readable_text(msg.get("content", ""))
|
||||
or msg.get("message", "")
|
||||
)
|
||||
for msg in current_group
|
||||
)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
|
||||
# Check length constraint (only if max_length != -1)
|
||||
message_length = len(readable_text)
|
||||
if max_length != -1 and current_length + message_length > max_length and current_group:
|
||||
# Current group would exceed max length, save it and start new
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
# Keep last few messages for overlap
|
||||
if overlap_messages > 0 and len(current_group) > overlap_messages:
|
||||
current_group = current_group[-overlap_messages:]
|
||||
current_length = sum(len(self._extract_readable_text(msg.get('content', '')) or msg.get('message', '')) for msg in current_group)
|
||||
current_length = sum(
|
||||
len(
|
||||
self._extract_readable_text(msg.get("content", ""))
|
||||
or msg.get("message", "")
|
||||
)
|
||||
for msg in current_group
|
||||
)
|
||||
else:
|
||||
current_group = []
|
||||
current_length = 0
|
||||
|
||||
|
||||
# Add message to current group
|
||||
current_group.append(message)
|
||||
current_length += message_length
|
||||
last_timestamp = create_time
|
||||
|
||||
|
||||
# Add the last group if it exists
|
||||
if current_group:
|
||||
concatenated_groups.append({
|
||||
'messages': current_group,
|
||||
'total_length': current_length,
|
||||
'start_time': current_group[0].get('createTime', 0),
|
||||
'end_time': current_group[-1].get('createTime', 0)
|
||||
})
|
||||
|
||||
concatenated_groups.append(
|
||||
{
|
||||
"messages": current_group,
|
||||
"total_length": current_length,
|
||||
"start_time": current_group[0].get("createTime", 0),
|
||||
"end_time": current_group[-1].get("createTime", 0),
|
||||
}
|
||||
)
|
||||
|
||||
return concatenated_groups
|
||||
|
||||
def _create_concatenated_content(self, message_group: Dict, contact_name: str) -> str:
|
||||
|
||||
def _create_concatenated_content(
|
||||
self, message_group: dict, contact_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Create concatenated content from a group of messages.
|
||||
|
||||
|
||||
Args:
|
||||
message_group: Dictionary containing messages and metadata
|
||||
contact_name: Name of the contact
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted concatenated content
|
||||
"""
|
||||
messages = message_group['messages']
|
||||
start_time = message_group['start_time']
|
||||
end_time = message_group['end_time']
|
||||
|
||||
messages = message_group["messages"]
|
||||
start_time = message_group["start_time"]
|
||||
end_time = message_group["end_time"]
|
||||
|
||||
# Format timestamps
|
||||
if start_time:
|
||||
try:
|
||||
start_timestamp = datetime.fromtimestamp(start_time)
|
||||
start_time_str = start_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
start_time_str = start_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
start_time_str = str(start_time)
|
||||
else:
|
||||
start_time_str = "Unknown"
|
||||
|
||||
|
||||
if end_time:
|
||||
try:
|
||||
end_timestamp = datetime.fromtimestamp(end_time)
|
||||
end_time_str = end_timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
end_time_str = end_timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
end_time_str = str(end_time)
|
||||
else:
|
||||
end_time_str = "Unknown"
|
||||
|
||||
|
||||
# Build concatenated message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text:
|
||||
readable_text = message_text
|
||||
|
||||
|
||||
# Format individual message
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
# change to YYYY-MM-DD HH:MM:SS
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
sender = "[Me]" if is_sent_from_self else "[Contact]"
|
||||
message_parts.append(f"({time_str}) {sender}: {readable_text}")
|
||||
|
||||
|
||||
concatenated_text = "\n".join(message_parts)
|
||||
|
||||
|
||||
# Create final document content
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
Time Range: {start_time_str} - {end_time_str}
|
||||
Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
Messages ({len(messages)} messages, {message_group["total_length"]} chars):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
# TODO @yichuan give better format and rich info here!
|
||||
doc_content = f"""
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content, contact_name
|
||||
|
||||
def load_data(self, input_dir: str = None, **load_kwargs: Any) -> List[Document]:
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load WeChat chat history data from exported JSON files.
|
||||
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing exported WeChat JSON files
|
||||
**load_kwargs:
|
||||
@@ -376,97 +408,104 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
time_window_minutes (int): Time window in minutes to group messages together (default: 30).
|
||||
overlap_messages (int): Number of messages to overlap between consecutive groups (default: 2).
|
||||
"""
|
||||
docs: List[Document] = []
|
||||
max_count = load_kwargs.get('max_count', 1000)
|
||||
wechat_export_dir = load_kwargs.get('wechat_export_dir', None)
|
||||
include_non_text = load_kwargs.get('include_non_text', False)
|
||||
concatenate_messages = load_kwargs.get('concatenate_messages', False)
|
||||
max_length = load_kwargs.get('max_length', 1000)
|
||||
time_window_minutes = load_kwargs.get('time_window_minutes', 30)
|
||||
|
||||
docs: list[Document] = []
|
||||
max_count = load_kwargs.get("max_count", 1000)
|
||||
wechat_export_dir = load_kwargs.get("wechat_export_dir", None)
|
||||
include_non_text = load_kwargs.get("include_non_text", False)
|
||||
concatenate_messages = load_kwargs.get("concatenate_messages", False)
|
||||
max_length = load_kwargs.get("max_length", 1000)
|
||||
time_window_minutes = load_kwargs.get("time_window_minutes", 30)
|
||||
|
||||
# Default WeChat export path
|
||||
if wechat_export_dir is None:
|
||||
wechat_export_dir = "./wechat_export_test"
|
||||
|
||||
|
||||
if not os.path.exists(wechat_export_dir):
|
||||
print(f"WeChat export directory not found at: {wechat_export_dir}")
|
||||
return docs
|
||||
|
||||
|
||||
try:
|
||||
# Find all JSON files in the export directory
|
||||
json_files = list(Path(wechat_export_dir).glob("*.json"))
|
||||
print(f"Found {len(json_files)} WeChat chat history files")
|
||||
|
||||
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
with open(json_file, encoding="utf-8") as f:
|
||||
chat_data = json.load(f)
|
||||
|
||||
|
||||
# Extract contact name from filename
|
||||
contact_name = json_file.stem
|
||||
|
||||
|
||||
if concatenate_messages:
|
||||
# Filter messages to only include readable text messages
|
||||
readable_messages = []
|
||||
for message in chat_data:
|
||||
try:
|
||||
content = message.get('content', '')
|
||||
content = message.get("content", "")
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
continue
|
||||
|
||||
|
||||
readable_messages.append(message)
|
||||
except Exception as e:
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Concatenate messages based on rules
|
||||
message_groups = self._concatenate_messages(
|
||||
readable_messages,
|
||||
max_length=-1,
|
||||
time_window_minutes=-1,
|
||||
overlap_messages=0 # Keep 2 messages overlap between groups
|
||||
readable_messages,
|
||||
max_length=max_length,
|
||||
time_window_minutes=time_window_minutes,
|
||||
overlap_messages=0, # No overlap between groups
|
||||
)
|
||||
|
||||
|
||||
# Create documents from concatenated groups
|
||||
for message_group in message_groups:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(message_group, contact_name)
|
||||
doc = Document(text=doc_content, metadata={"contact_name": contact_name})
|
||||
|
||||
doc_content, contact_name = self._create_concatenated_content(
|
||||
message_group, contact_name
|
||||
)
|
||||
doc = Document(
|
||||
text=doc_content,
|
||||
metadata={"contact_name": contact_name},
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
print(f"Created {len(message_groups)} concatenated message groups for {contact_name}")
|
||||
|
||||
|
||||
print(
|
||||
f"Created {len(message_groups)} concatenated message groups for {contact_name}"
|
||||
)
|
||||
|
||||
else:
|
||||
# Original single-message processing
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
# Extract message information
|
||||
from_user = message.get('fromUser', '')
|
||||
to_user = message.get('toUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
is_sent_from_self = message.get('isSentFromSelf', False)
|
||||
|
||||
message.get("fromUser", "")
|
||||
message.get("toUser", "")
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
is_sent_from_self = message.get("isSentFromSelf", False)
|
||||
|
||||
# Handle content that might be dict or string
|
||||
try:
|
||||
# Check if this is a readable text message
|
||||
if not include_non_text and not self._is_text_message(content):
|
||||
continue
|
||||
|
||||
|
||||
# Extract readable text
|
||||
readable_text = self._extract_readable_text(content)
|
||||
if not readable_text and not include_non_text:
|
||||
@@ -475,17 +514,17 @@ Messages ({len(messages)} messages, {message_group['total_length']} chars):
|
||||
# Skip messages that cause processing errors
|
||||
print(f"Error processing message in {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Convert timestamp to readable format
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
# Create document content with metadata header and contact info
|
||||
doc_content = f"""
|
||||
Contact: {contact_name}
|
||||
@@ -493,57 +532,66 @@ Is sent from self: {is_sent_from_self}
|
||||
Time: {time_str}
|
||||
Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
|
||||
|
||||
# Create document with embedded metadata
|
||||
doc = Document(text=doc_content, metadata={})
|
||||
doc = Document(
|
||||
text=doc_content, metadata={"contact_name": contact_name}
|
||||
)
|
||||
docs.append(doc)
|
||||
count += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Loaded {len(docs)} WeChat chat documents")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading WeChat history: {e}")
|
||||
return docs
|
||||
|
||||
|
||||
return docs
|
||||
|
||||
@staticmethod
|
||||
def find_wechat_export_dirs() -> List[Path]:
|
||||
def find_wechat_export_dirs() -> list[Path]:
|
||||
"""
|
||||
Find all WeChat export directories.
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
|
||||
# Look for common export directory names
|
||||
possible_dirs = [
|
||||
Path("./wechat_export_test"),
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
Path("./chat_export"),
|
||||
]
|
||||
|
||||
|
||||
for export_dir in possible_dirs:
|
||||
if export_dir.exists() and export_dir.is_dir():
|
||||
json_files = list(export_dir.glob("*.json"))
|
||||
if json_files:
|
||||
export_dirs.append(export_dir)
|
||||
print(f"Found WeChat export directory: {export_dir} with {len(json_files)} files")
|
||||
|
||||
print(
|
||||
f"Found WeChat export directory: {export_dir} with {len(json_files)} files"
|
||||
)
|
||||
|
||||
print(f"Found {len(export_dirs)} WeChat export directories")
|
||||
return export_dirs
|
||||
|
||||
@staticmethod
|
||||
def export_chat_to_file(output_file: str = "wechat_chat_export.txt", max_count: int = 1000, export_dir: str = None, include_non_text: bool = False):
|
||||
def export_chat_to_file(
|
||||
output_file: str = "wechat_chat_export.txt",
|
||||
max_count: int = 1000,
|
||||
export_dir: str | None = None,
|
||||
include_non_text: bool = False,
|
||||
):
|
||||
"""
|
||||
Export WeChat chat history to a text file.
|
||||
|
||||
|
||||
Args:
|
||||
output_file: Path to the output file
|
||||
max_count: Maximum number of entries to export
|
||||
@@ -552,36 +600,36 @@ Message: {readable_text if readable_text else message_text}
|
||||
"""
|
||||
if export_dir is None:
|
||||
export_dir = "./wechat_export_test"
|
||||
|
||||
|
||||
if not os.path.exists(export_dir):
|
||||
print(f"WeChat export directory not found at: {export_dir}")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
json_files = list(Path(export_dir).glob("*.json"))
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
count = 0
|
||||
for json_file in json_files:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(json_file, 'r', encoding='utf-8') as json_f:
|
||||
with open(json_file, encoding="utf-8") as json_f:
|
||||
chat_data = json.load(json_f)
|
||||
|
||||
|
||||
contact_name = json_file.stem
|
||||
f.write(f"\n=== Chat with {contact_name} ===\n")
|
||||
|
||||
|
||||
for message in chat_data:
|
||||
if count >= max_count and max_count > 0:
|
||||
break
|
||||
|
||||
from_user = message.get('fromUser', '')
|
||||
content = message.get('content', '')
|
||||
message_text = message.get('message', '')
|
||||
create_time = message.get('createTime', 0)
|
||||
|
||||
|
||||
from_user = message.get("fromUser", "")
|
||||
content = message.get("content", "")
|
||||
message_text = message.get("message", "")
|
||||
create_time = message.get("createTime", 0)
|
||||
|
||||
# Skip non-text messages unless requested
|
||||
if not include_non_text:
|
||||
reader = WeChatHistoryReader()
|
||||
@@ -591,83 +639,90 @@ Message: {readable_text if readable_text else message_text}
|
||||
if not readable_text:
|
||||
continue
|
||||
message_text = readable_text
|
||||
|
||||
|
||||
if create_time:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(create_time)
|
||||
time_str = timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
time_str = str(create_time)
|
||||
else:
|
||||
time_str = "Unknown"
|
||||
|
||||
|
||||
f.write(f"[{time_str}] {from_user}: {message_text}\n")
|
||||
count += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {json_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
print(f"Exported {count} chat entries to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error exporting WeChat chat history: {e}")
|
||||
|
||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Optional[Path]:
|
||||
def export_wechat_chat_history(self, export_dir: str = "./wechat_export_direct") -> Path | None:
|
||||
"""
|
||||
Export WeChat chat history using wechat-exporter tool.
|
||||
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history
|
||||
|
||||
|
||||
Returns:
|
||||
Path to export directory if successful, None otherwise
|
||||
"""
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
# Create export directory
|
||||
export_path = Path(export_dir)
|
||||
export_path.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
print(f"Exporting WeChat chat history to {export_path}...")
|
||||
|
||||
|
||||
# Check if wechat-exporter directory exists
|
||||
if not self.wechat_exporter_dir.exists():
|
||||
print(f"wechat-exporter directory not found at: {self.wechat_exporter_dir}")
|
||||
return None
|
||||
|
||||
|
||||
# Install requirements if needed
|
||||
requirements_file = self.wechat_exporter_dir / "requirements.txt"
|
||||
if requirements_file.exists():
|
||||
print("Installing wechat-exporter requirements...")
|
||||
subprocess.run([
|
||||
"uv", "pip", "install", "-r", str(requirements_file)
|
||||
], check=True)
|
||||
|
||||
subprocess.run(["uv", "pip", "install", "-r", str(requirements_file)], check=True)
|
||||
|
||||
# Run the export command
|
||||
print("Running wechat-exporter...")
|
||||
result = subprocess.run([
|
||||
sys.executable, str(self.wechat_exporter_dir / "main.py"),
|
||||
"export-all", str(export_path)
|
||||
], capture_output=True, text=True, check=True)
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.wechat_exporter_dir / "main.py"),
|
||||
"export-all",
|
||||
str(export_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
print("Export command output:")
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print("Export errors:")
|
||||
print(result.stderr)
|
||||
|
||||
|
||||
# Check if export was successful
|
||||
if export_path.exists() and any(export_path.glob("*.json")):
|
||||
json_files = list(export_path.glob("*.json"))
|
||||
print(f"Successfully exported {len(json_files)} chat history files to {export_path}")
|
||||
print(
|
||||
f"Successfully exported {len(json_files)} chat history files to {export_path}"
|
||||
)
|
||||
return export_path
|
||||
else:
|
||||
print("Export completed but no JSON files found")
|
||||
return None
|
||||
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Export command failed: {e}")
|
||||
print(f"Command output: {e.stdout}")
|
||||
@@ -678,18 +733,18 @@ Message: {readable_text if readable_text else message_text}
|
||||
print("Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
return None
|
||||
|
||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> List[Path]:
|
||||
def find_or_export_wechat_data(self, export_dir: str = "./wechat_export_direct") -> list[Path]:
|
||||
"""
|
||||
Find existing WeChat exports or create new ones.
|
||||
|
||||
|
||||
Args:
|
||||
export_dir: Directory to save exported chat history if needed
|
||||
|
||||
|
||||
Returns:
|
||||
List of Path objects pointing to WeChat export directories
|
||||
"""
|
||||
export_dirs = []
|
||||
|
||||
|
||||
# Look for existing exports in common locations
|
||||
possible_export_dirs = [
|
||||
Path("./wechat_database_export"),
|
||||
@@ -697,23 +752,25 @@ Message: {readable_text if readable_text else message_text}
|
||||
Path("./wechat_export"),
|
||||
Path("./wechat_export_direct"),
|
||||
Path("./wechat_chat_history"),
|
||||
Path("./chat_export")
|
||||
Path("./chat_export"),
|
||||
]
|
||||
|
||||
|
||||
for export_dir_path in possible_export_dirs:
|
||||
if export_dir_path.exists() and any(export_dir_path.glob("*.json")):
|
||||
export_dirs.append(export_dir_path)
|
||||
print(f"Found existing export: {export_dir_path}")
|
||||
|
||||
|
||||
# If no existing exports, try to export automatically
|
||||
if not export_dirs:
|
||||
print("No existing WeChat exports found. Starting direct export...")
|
||||
|
||||
|
||||
# Try to export using wechat-exporter
|
||||
exported_path = self.export_wechat_chat_history(export_dir)
|
||||
if exported_path:
|
||||
export_dirs = [exported_path]
|
||||
else:
|
||||
print("Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed.")
|
||||
|
||||
return export_dirs
|
||||
print(
|
||||
"Failed to export WeChat data. Please ensure WeChat is running and WeChatTweak is installed."
|
||||
)
|
||||
|
||||
return export_dirs
|
||||
219
apps/image_rag.py
Normal file
219
apps/image_rag.py
Normal file
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
CLIP Image RAG Application
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on images using CLIP embeddings.
|
||||
You can index a directory of images and search them using text queries.
|
||||
|
||||
Usage:
|
||||
python -m apps.image_rag --image-dir ./my_images/ --query "a sunset over mountains"
|
||||
python -m apps.image_rag --image-dir ./my_images/ --interactive
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pickle
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
|
||||
|
||||
class ImageRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for images using CLIP embeddings.
|
||||
|
||||
This class provides a complete RAG pipeline for image data, including
|
||||
CLIP embedding generation, indexing, and text-based image search.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Image RAG",
|
||||
description="RAG application for images using CLIP embeddings",
|
||||
default_index_name="image_index",
|
||||
)
|
||||
# Override default embedding model to use CLIP
|
||||
self.embedding_model_default = "clip-ViT-L-14"
|
||||
self.embedding_mode_default = "sentence-transformers"
|
||||
self._image_data: list[dict] = []
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add image-specific arguments."""
|
||||
image_group = parser.add_argument_group("Image Parameters")
|
||||
image_group.add_argument(
|
||||
"--image-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing images to index",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--image-extensions",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=[".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"],
|
||||
help="Image file extensions to process (default: .jpg .jpeg .png .gif .bmp .webp)",
|
||||
)
|
||||
image_group.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size for CLIP embedding generation (default: 32)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load images, generate CLIP embeddings, and return text descriptions."""
|
||||
self._image_data = self._load_images_and_embeddings(args)
|
||||
return [entry["text"] for entry in self._image_data]
|
||||
|
||||
def _load_images_and_embeddings(self, args) -> list[dict]:
|
||||
"""Helper to process images and produce embeddings/metadata."""
|
||||
image_dir = Path(args.image_dir)
|
||||
if not image_dir.exists():
|
||||
raise ValueError(f"Image directory does not exist: {image_dir}")
|
||||
|
||||
print(f"📸 Loading images from {image_dir}...")
|
||||
|
||||
# Find all image files
|
||||
image_files = []
|
||||
for ext in args.image_extensions:
|
||||
image_files.extend(image_dir.rglob(f"*{ext}"))
|
||||
image_files.extend(image_dir.rglob(f"*{ext.upper()}"))
|
||||
|
||||
if not image_files:
|
||||
raise ValueError(
|
||||
f"No images found in {image_dir} with extensions {args.image_extensions}"
|
||||
)
|
||||
|
||||
print(f"✅ Found {len(image_files)} images")
|
||||
|
||||
# Limit if max_items is set
|
||||
if args.max_items > 0:
|
||||
image_files = image_files[: args.max_items]
|
||||
print(f"📊 Processing {len(image_files)} images (limited by --max-items)")
|
||||
|
||||
# Load CLIP model
|
||||
print("🔍 Loading CLIP model...")
|
||||
model = SentenceTransformer(self.embedding_model_default)
|
||||
|
||||
# Process images and generate embeddings
|
||||
print("🖼️ Processing images and generating embeddings...")
|
||||
image_data = []
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
for image_path in tqdm(image_files, desc="Processing images"):
|
||||
try:
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
batch_images.append(image)
|
||||
batch_paths.append(image_path)
|
||||
|
||||
# Process in batches
|
||||
if len(batch_images) >= args.batch_size:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=args.batch_size,
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
batch_images = []
|
||||
batch_paths = []
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to process {image_path}: {e}")
|
||||
continue
|
||||
|
||||
# Process remaining images
|
||||
if batch_images:
|
||||
embeddings = model.encode(
|
||||
batch_images,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=len(batch_images),
|
||||
show_progress_bar=False,
|
||||
)
|
||||
|
||||
for img_path, embedding in zip(batch_paths, embeddings):
|
||||
image_data.append(
|
||||
{
|
||||
"text": f"Image: {img_path.name}\nPath: {img_path}",
|
||||
"metadata": {
|
||||
"image_path": str(img_path),
|
||||
"image_name": img_path.name,
|
||||
"image_dir": str(image_dir),
|
||||
},
|
||||
"embedding": embedding.astype(np.float32),
|
||||
}
|
||||
)
|
||||
|
||||
print(f"✅ Processed {len(image_data)} images")
|
||||
return image_data
|
||||
|
||||
async def build_index(self, args, texts: list[dict[str, Any]]) -> str:
|
||||
"""Build index using pre-computed CLIP embeddings."""
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if not self._image_data or len(self._image_data) != len(texts):
|
||||
raise RuntimeError("No image data found. Make sure load_data() ran successfully.")
|
||||
|
||||
print("🔨 Building LEANN index with CLIP embeddings...")
|
||||
builder = LeannBuilder(
|
||||
backend_name=args.backend_name,
|
||||
embedding_model=self.embedding_model_default,
|
||||
embedding_mode=self.embedding_mode_default,
|
||||
is_recompute=False,
|
||||
distance_metric="cosine",
|
||||
graph_degree=args.graph_degree,
|
||||
build_complexity=args.build_complexity,
|
||||
is_compact=not args.no_compact,
|
||||
)
|
||||
|
||||
for text, data in zip(texts, self._image_data):
|
||||
builder.add_text(text=text, metadata=data["metadata"])
|
||||
|
||||
ids = [str(i) for i in range(len(self._image_data))]
|
||||
embeddings = np.array([data["embedding"] for data in self._image_data], dtype=np.float32)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="wb", suffix=".pkl", delete=False) as f:
|
||||
pickle.dump((ids, embeddings), f)
|
||||
pkl_path = f.name
|
||||
|
||||
try:
|
||||
index_path = str(Path(args.index_dir) / f"{self.default_index_name}.leann")
|
||||
builder.build_index_from_embeddings(index_path, pkl_path)
|
||||
print(f"✅ Index built successfully at {index_path}")
|
||||
return index_path
|
||||
finally:
|
||||
Path(pkl_path).unlink()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for the image RAG application."""
|
||||
import asyncio
|
||||
|
||||
app = ImageRAG()
|
||||
asyncio.run(app.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
apps/imessage_data/__init__.py
Normal file
1
apps/imessage_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""iMessage data processing module."""
|
||||
342
apps/imessage_data/imessage_reader.py
Normal file
342
apps/imessage_data/imessage_reader.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads and processes iMessage conversation data from the macOS Messages database.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core import Document
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
|
||||
|
||||
class IMessageReader(BaseReader):
|
||||
"""
|
||||
iMessage data reader.
|
||||
|
||||
Reads iMessage conversation data from the macOS Messages database (chat.db).
|
||||
Processes conversations into structured documents with metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, concatenate_conversations: bool = True) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
Args:
|
||||
concatenate_conversations: Whether to concatenate messages within conversations for better context
|
||||
"""
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
|
||||
def _get_default_chat_db_path(self) -> Path:
|
||||
"""
|
||||
Get the default path to the iMessage chat database.
|
||||
|
||||
Returns:
|
||||
Path to the chat.db file
|
||||
"""
|
||||
home = Path.home()
|
||||
return home / "Library" / "Messages" / "chat.db"
|
||||
|
||||
def _convert_cocoa_timestamp(self, cocoa_timestamp: int) -> str:
|
||||
"""
|
||||
Convert Cocoa timestamp to readable format.
|
||||
|
||||
Args:
|
||||
cocoa_timestamp: Timestamp in Cocoa format (nanoseconds since 2001-01-01)
|
||||
|
||||
Returns:
|
||||
Formatted timestamp string
|
||||
"""
|
||||
if cocoa_timestamp == 0:
|
||||
return "Unknown"
|
||||
|
||||
try:
|
||||
# Cocoa timestamp is nanoseconds since 2001-01-01 00:00:00 UTC
|
||||
# Convert to seconds and add to Unix epoch
|
||||
cocoa_epoch = datetime(2001, 1, 1)
|
||||
unix_timestamp = cocoa_timestamp / 1_000_000_000 # Convert nanoseconds to seconds
|
||||
message_time = cocoa_epoch.timestamp() + unix_timestamp
|
||||
return datetime.fromtimestamp(message_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "Unknown"
|
||||
|
||||
def _get_contact_name(self, handle_id: str) -> str:
|
||||
"""
|
||||
Get a readable contact name from handle ID.
|
||||
|
||||
Args:
|
||||
handle_id: The handle ID (phone number or email)
|
||||
|
||||
Returns:
|
||||
Formatted contact name
|
||||
"""
|
||||
if not handle_id:
|
||||
return "Unknown"
|
||||
|
||||
# Clean up phone numbers and emails for display
|
||||
if "@" in handle_id:
|
||||
return handle_id # Email address
|
||||
elif handle_id.startswith("+"):
|
||||
return handle_id # International phone number
|
||||
else:
|
||||
# Try to format as phone number
|
||||
digits = "".join(filter(str.isdigit, handle_id))
|
||||
if len(digits) == 10:
|
||||
return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
|
||||
elif len(digits) == 11 and digits[0] == "1":
|
||||
return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
|
||||
else:
|
||||
return handle_id
|
||||
|
||||
def _read_messages_from_db(self, db_path: Path) -> list[dict]:
|
||||
"""
|
||||
Read messages from the iMessage database.
|
||||
|
||||
Args:
|
||||
db_path: Path to the chat.db file
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
if not db_path.exists():
|
||||
print(f"iMessage database not found at: {db_path}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query to get messages with chat and handle information
|
||||
query = """
|
||||
SELECT
|
||||
m.ROWID as message_id,
|
||||
m.text,
|
||||
m.date,
|
||||
m.is_from_me,
|
||||
m.service,
|
||||
c.chat_identifier,
|
||||
c.display_name as chat_display_name,
|
||||
h.id as handle_id,
|
||||
c.ROWID as chat_id
|
||||
FROM message m
|
||||
LEFT JOIN chat_message_join cmj ON m.ROWID = cmj.message_id
|
||||
LEFT JOIN chat c ON cmj.chat_id = c.ROWID
|
||||
LEFT JOIN handle h ON m.handle_id = h.ROWID
|
||||
WHERE m.text IS NOT NULL AND m.text != ''
|
||||
ORDER BY c.ROWID, m.date
|
||||
"""
|
||||
|
||||
cursor.execute(query)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
messages = []
|
||||
for row in rows:
|
||||
(
|
||||
message_id,
|
||||
text,
|
||||
date,
|
||||
is_from_me,
|
||||
service,
|
||||
chat_identifier,
|
||||
chat_display_name,
|
||||
handle_id,
|
||||
chat_id,
|
||||
) = row
|
||||
|
||||
message = {
|
||||
"message_id": message_id,
|
||||
"text": text,
|
||||
"timestamp": self._convert_cocoa_timestamp(date),
|
||||
"is_from_me": bool(is_from_me),
|
||||
"service": service or "iMessage",
|
||||
"chat_identifier": chat_identifier or "Unknown",
|
||||
"chat_display_name": chat_display_name or "Unknown Chat",
|
||||
"handle_id": handle_id or "Unknown",
|
||||
"contact_name": self._get_contact_name(handle_id or ""),
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
messages.append(message)
|
||||
|
||||
conn.close()
|
||||
print(f"Found {len(messages)} messages in database")
|
||||
return messages
|
||||
|
||||
except sqlite3.Error as e:
|
||||
print(f"Error reading iMessage database: {e}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"Unexpected error reading iMessage database: {e}")
|
||||
return []
|
||||
|
||||
def _group_messages_by_chat(self, messages: list[dict]) -> dict[int, list[dict]]:
|
||||
"""
|
||||
Group messages by chat ID.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Dictionary mapping chat_id to list of messages
|
||||
"""
|
||||
chats = {}
|
||||
for message in messages:
|
||||
chat_id = message["chat_id"]
|
||||
if chat_id not in chats:
|
||||
chats[chat_id] = []
|
||||
chats[chat_id].append(message)
|
||||
|
||||
return chats
|
||||
|
||||
def _create_concatenated_content(self, chat_id: int, messages: list[dict]) -> str:
|
||||
"""
|
||||
Create concatenated content from chat messages.
|
||||
|
||||
Args:
|
||||
chat_id: The chat ID
|
||||
messages: List of messages in the chat
|
||||
|
||||
Returns:
|
||||
Concatenated text content
|
||||
"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Get chat info from first message
|
||||
first_msg = messages[0]
|
||||
chat_name = first_msg["chat_display_name"]
|
||||
chat_identifier = first_msg["chat_identifier"]
|
||||
|
||||
# Build message content
|
||||
message_parts = []
|
||||
for message in messages:
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
|
||||
if is_from_me:
|
||||
prefix = "[You]"
|
||||
else:
|
||||
prefix = f"[{contact_name}]"
|
||||
|
||||
if timestamp != "Unknown":
|
||||
prefix += f" ({timestamp})"
|
||||
|
||||
message_parts.append(f"{prefix}: {text}")
|
||||
|
||||
concatenated_text = "\n\n".join(message_parts)
|
||||
|
||||
doc_content = f"""Chat: {chat_name}
|
||||
Identifier: {chat_identifier}
|
||||
Messages ({len(messages)} messages):
|
||||
|
||||
{concatenated_text}
|
||||
"""
|
||||
return doc_content
|
||||
|
||||
def _create_individual_content(self, message: dict) -> str:
|
||||
"""
|
||||
Create content for individual message.
|
||||
|
||||
Args:
|
||||
message: Message dictionary
|
||||
|
||||
Returns:
|
||||
Formatted message content
|
||||
"""
|
||||
timestamp = message["timestamp"]
|
||||
is_from_me = message["is_from_me"]
|
||||
text = message["text"]
|
||||
contact_name = message["contact_name"]
|
||||
chat_name = message["chat_display_name"]
|
||||
|
||||
sender = "You" if is_from_me else contact_name
|
||||
|
||||
return f"""Message from {sender} in chat "{chat_name}"
|
||||
Time: {timestamp}
|
||||
Content: {text}
|
||||
"""
|
||||
|
||||
def load_data(self, input_dir: str | None = None, **load_kwargs: Any) -> list[Document]:
|
||||
"""
|
||||
Load iMessage data and return as documents.
|
||||
|
||||
Args:
|
||||
input_dir: Optional path to directory containing chat.db file.
|
||||
If not provided, uses default macOS location.
|
||||
**load_kwargs: Additional arguments (unused)
|
||||
|
||||
Returns:
|
||||
List of Document objects containing iMessage data
|
||||
"""
|
||||
docs = []
|
||||
|
||||
# Determine database path
|
||||
if input_dir:
|
||||
db_path = Path(input_dir) / "chat.db"
|
||||
else:
|
||||
db_path = self._get_default_chat_db_path()
|
||||
|
||||
print(f"Reading iMessage database from: {db_path}")
|
||||
|
||||
# Read messages from database
|
||||
messages = self._read_messages_from_db(db_path)
|
||||
if not messages:
|
||||
return docs
|
||||
|
||||
if self.concatenate_conversations:
|
||||
# Group messages by chat and create concatenated documents
|
||||
chats = self._group_messages_by_chat(messages)
|
||||
|
||||
for chat_id, chat_messages in chats.items():
|
||||
if not chat_messages:
|
||||
continue
|
||||
|
||||
content = self._create_concatenated_content(chat_id, chat_messages)
|
||||
|
||||
# Create metadata
|
||||
first_msg = chat_messages[0]
|
||||
last_msg = chat_messages[-1]
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"chat_id": chat_id,
|
||||
"chat_name": first_msg["chat_display_name"],
|
||||
"chat_identifier": first_msg["chat_identifier"],
|
||||
"message_count": len(chat_messages),
|
||||
"first_message_date": first_msg["timestamp"],
|
||||
"last_message_date": last_msg["timestamp"],
|
||||
"participants": list(
|
||||
{msg["contact_name"] for msg in chat_messages if not msg["is_from_me"]}
|
||||
),
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
else:
|
||||
# Create individual documents for each message
|
||||
for message in messages:
|
||||
content = self._create_individual_content(message)
|
||||
|
||||
metadata = {
|
||||
"source": "iMessage",
|
||||
"message_id": message["message_id"],
|
||||
"chat_id": message["chat_id"],
|
||||
"chat_name": message["chat_display_name"],
|
||||
"chat_identifier": message["chat_identifier"],
|
||||
"timestamp": message["timestamp"],
|
||||
"is_from_me": message["is_from_me"],
|
||||
"contact_name": message["contact_name"],
|
||||
"service": message["service"],
|
||||
}
|
||||
|
||||
doc = Document(text=content, metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
||||
print(f"Created {len(docs)} documents from iMessage data")
|
||||
return docs
|
||||
126
apps/imessage_rag.py
Normal file
126
apps/imessage_rag.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
iMessage RAG Example.
|
||||
|
||||
This example demonstrates how to build a RAG system on your iMessage conversation history.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from leann.chunking_utils import create_text_chunks
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.imessage_data.imessage_reader import IMessageReader
|
||||
|
||||
|
||||
class IMessageRAG(BaseRAGExample):
|
||||
"""RAG example for iMessage conversation history."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="iMessage",
|
||||
description="RAG on your iMessage conversation history",
|
||||
default_index_name="imessage_index",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add iMessage-specific arguments."""
|
||||
imessage_group = parser.add_argument_group("iMessage Parameters")
|
||||
imessage_group.add_argument(
|
||||
"--db-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to iMessage chat.db file (default: ~/Library/Messages/chat.db)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Concatenate messages within conversations for better context (default: True)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process each message individually instead of concatenating by conversation",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum characters per text chunk (default: 1000)",
|
||||
)
|
||||
imessage_group.add_argument(
|
||||
"--chunk-overlap",
|
||||
type=int,
|
||||
default=200,
|
||||
help="Overlap between text chunks (default: 200)",
|
||||
)
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load iMessage history and convert to text chunks."""
|
||||
print("Loading iMessage conversation history...")
|
||||
|
||||
# Determine concatenation setting
|
||||
concatenate = args.concatenate_conversations and not args.no_concatenate_conversations
|
||||
|
||||
# Initialize iMessage reader
|
||||
reader = IMessageReader(concatenate_conversations=concatenate)
|
||||
|
||||
# Load documents
|
||||
try:
|
||||
if args.db_path:
|
||||
# Use custom database path
|
||||
db_dir = str(Path(args.db_path).parent)
|
||||
documents = reader.load_data(input_dir=db_dir)
|
||||
else:
|
||||
# Use default macOS location
|
||||
documents = reader.load_data()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading iMessage data: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure you have granted Full Disk Access to your terminal/IDE")
|
||||
print("2. Check that the iMessage database exists at ~/Library/Messages/chat.db")
|
||||
print("3. Try specifying a custom path with --db-path if you have a backup")
|
||||
return []
|
||||
|
||||
if not documents:
|
||||
print("No iMessage conversations found!")
|
||||
return []
|
||||
|
||||
print(f"Loaded {len(documents)} iMessage documents")
|
||||
|
||||
# Show some statistics
|
||||
total_messages = sum(doc.metadata.get("message_count", 1) for doc in documents)
|
||||
print(f"Total messages: {total_messages}")
|
||||
|
||||
if concatenate:
|
||||
# Show chat statistics
|
||||
chat_names = [doc.metadata.get("chat_name", "Unknown") for doc in documents]
|
||||
unique_chats = len(set(chat_names))
|
||||
print(f"Unique conversations: {unique_chats}")
|
||||
|
||||
# Convert to text chunks
|
||||
all_texts = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=args.chunk_size,
|
||||
chunk_overlap=args.chunk_overlap,
|
||||
)
|
||||
|
||||
# Apply max_items limit if specified
|
||||
if args.max_items > 0:
|
||||
all_texts = all_texts[: args.max_items]
|
||||
print(f"Limited to {len(all_texts)} text chunks (max_items={args.max_items})")
|
||||
|
||||
return all_texts
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
app = IMessageRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
113
apps/multimodal/vision-based-pdf-multi-vector/README.md
Normal file
@@ -0,0 +1,113 @@
|
||||
## Vision-based PDF Multi-Vector Demos (macOS/MPS)
|
||||
|
||||
This folder contains two demos to index PDF pages as images and run multi-vector retrieval with ColPali/ColQwen2, plus optional similarity map visualization and answer generation.
|
||||
|
||||
### What you’ll run
|
||||
- `multi-vector-leann-paper-example.py`: local PDF → pages → embed → build HNSW index → search.
|
||||
- `multi-vector-leann-similarity-map.py`: HF dataset (default) or local pages → embed → index → retrieve → similarity maps → optional Qwen-VL answer.
|
||||
|
||||
## Prerequisites (macOS)
|
||||
|
||||
### 1) Homebrew poppler (for pdf2image)
|
||||
```bash
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### 2) Python environment
|
||||
Use uv (recommended) or pip. Python 3.9+.
|
||||
|
||||
Using uv:
|
||||
```bash
|
||||
uv pip install \
|
||||
colpali_engine \
|
||||
pdf2image \
|
||||
pillow \
|
||||
matplotlib qwen_vl_utils \
|
||||
einops \
|
||||
seaborn
|
||||
```
|
||||
|
||||
Notes:
|
||||
- On first run, models download from Hugging Face. Login/config if needed.
|
||||
- The scripts auto-select device: CUDA > MPS > CPU. Verify MPS:
|
||||
```bash
|
||||
python -c "import torch; print('MPS available:', bool(getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available()))"
|
||||
```
|
||||
|
||||
## Run the demos
|
||||
|
||||
### A) Local PDF example
|
||||
Converts a local PDF into page images, embeds them, builds an index, and searches.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
# If you don't have the sample PDF locally, download it (ignored by Git)
|
||||
mkdir -p pdfs
|
||||
curl -L -o pdfs/2004.12832v2.pdf https://arxiv.org/pdf/2004.12832.pdf
|
||||
ls pdfs/2004.12832v2.pdf
|
||||
# Ensure output dir exists
|
||||
mkdir -p pages
|
||||
python multi-vector-leann-paper-example.py
|
||||
```
|
||||
Expected:
|
||||
- Page images in `pages/`.
|
||||
- Console prints like `Using device=mps, dtype=...` and retrieved file paths for queries.
|
||||
|
||||
To use your own PDF: edit `pdf_path` near the top of the script.
|
||||
|
||||
### B) Similarity map + answer demo
|
||||
Uses HF dataset `weaviate/arXiv-AI-papers-multi-vector` by default; can switch to local pages.
|
||||
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
Artifacts (when enabled):
|
||||
- Retrieved pages: `./figures/retrieved_page_rank{K}.png`
|
||||
- Similarity maps: `./figures/similarity_map_rank{K}.png`
|
||||
|
||||
Key knobs in the script (top of file):
|
||||
- `QUERY`: your question
|
||||
- `MODEL`: `"colqwen2"` or `"colpali"`
|
||||
- `USE_HF_DATASET`: set `False` to use local pages
|
||||
- `PDF`, `PAGES_DIR`: for local mode
|
||||
- `INDEX_PATH`, `TOPK`, `FIRST_STAGE_K`, `REBUILD_INDEX`
|
||||
- `SIMILARITY_MAP`, `SIM_TOKEN_IDX`, `SIM_OUTPUT`
|
||||
- `ANSWER`, `MAX_NEW_TOKENS` (Qwen-VL)
|
||||
|
||||
## Troubleshooting
|
||||
- pdf2image errors on macOS: ensure `brew install poppler` and `pdfinfo` works in terminal.
|
||||
- Slow or OOM on MPS: reduce dataset size (e.g., set `MAX_DOCS`) or switch to CPU.
|
||||
- NaNs on MPS: keep fp32 on MPS (default in similarity-map script); avoid fp16 there.
|
||||
- First-run model downloads can be large; ensure network access (HF mirrors if needed).
|
||||
|
||||
## Notes
|
||||
- Index files are under `./indexes/`. Delete or set `REBUILD_INDEX=True` to rebuild.
|
||||
- For local PDFs, page images go to `./pages/`.
|
||||
|
||||
|
||||
### Retrieval and Visualization Example
|
||||
|
||||
Example settings in `multi-vector-leann-similarity-map.py`:
|
||||
- `QUERY = "How does DeepSeek-V2 compare against the LLaMA family of LLMs?"`
|
||||
- `SIMILARITY_MAP = True` (to generate heatmaps)
|
||||
- `TOPK = 1` (save the top retrieved page and its similarity map)
|
||||
|
||||
Run:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Outputs (by default):
|
||||
- Retrieved page: `./figures/retrieved_page_rank1.png`
|
||||
- Similarity map: `./figures/similarity_map_rank1.png`
|
||||
|
||||
Sample visualization (example result, and the query is "QUERY = "How does Vim model performance and efficiency compared to other models?"
|
||||
"):
|
||||

|
||||
|
||||
Notes:
|
||||
- Set `SIM_TOKEN_IDX` to visualize a specific token index; set `-1` to auto-select the most salient token.
|
||||
- If you change `SIM_OUTPUT` to a file path (e.g., `./figures/my_map.png`), multiple ranks are saved as `my_map_rank{K}.png`.
|
||||
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
132
apps/multimodal/vision-based-pdf-multi-vector/colqwen_forward.py
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple test script to test colqwen2 forward pass with a single image."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the current directory to path to import leann_multi_vector
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from leann_multi_vector import _embed_images, _ensure_repo_paths_importable, _load_colvision
|
||||
from PIL import Image
|
||||
|
||||
# Ensure repo paths are importable
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Set environment variable
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def create_test_image():
|
||||
"""Create a simple test image."""
|
||||
# Create a simple RGB image (800x600)
|
||||
img = Image.new("RGB", (800, 600), color="white")
|
||||
return img
|
||||
|
||||
|
||||
def load_test_image_from_file():
|
||||
"""Try to load an image from the indexes directory if available."""
|
||||
# Try to find an existing image in the indexes directory
|
||||
indexes_dir = Path(__file__).parent / "indexes"
|
||||
|
||||
# Look for images in common locations
|
||||
possible_paths = [
|
||||
indexes_dir / "vidore_fastplaid" / "images",
|
||||
indexes_dir / "colvision_large.leann.images",
|
||||
indexes_dir / "colvision.leann.images",
|
||||
]
|
||||
|
||||
for img_dir in possible_paths:
|
||||
if img_dir.exists():
|
||||
# Find first image file
|
||||
for ext in [".png", ".jpg", ".jpeg"]:
|
||||
for img_file in img_dir.glob(f"*{ext}"):
|
||||
print(f"Loading test image from: {img_file}")
|
||||
return Image.open(img_file)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 60)
|
||||
print("Testing ColQwen2 Forward Pass")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Load or create test image
|
||||
print("\n[Step 1] Loading test image...")
|
||||
test_image = load_test_image_from_file()
|
||||
if test_image is None:
|
||||
print("No existing image found, creating a simple test image...")
|
||||
test_image = create_test_image()
|
||||
else:
|
||||
print(f"✓ Loaded image: {test_image.size} ({test_image.mode})")
|
||||
|
||||
# Convert to RGB if needed
|
||||
if test_image.mode != "RGB":
|
||||
test_image = test_image.convert("RGB")
|
||||
print(f"✓ Converted to RGB: {test_image.size}")
|
||||
|
||||
# Step 2: Load model
|
||||
print("\n[Step 2] Loading ColQwen2 model...")
|
||||
try:
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision("colqwen2")
|
||||
print(f"✓ Model loaded: {model_name}")
|
||||
print(f"✓ Device: {device_str}, dtype: {dtype}")
|
||||
|
||||
# Print model info
|
||||
if hasattr(model, "device"):
|
||||
print(f"✓ Model device: {model.device}")
|
||||
if hasattr(model, "dtype"):
|
||||
print(f"✓ Model dtype: {model.dtype}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# Step 3: Test forward pass
|
||||
print("\n[Step 3] Running forward pass...")
|
||||
try:
|
||||
# Use the _embed_images function which handles batching and forward pass
|
||||
images = [test_image]
|
||||
print(f"Processing {len(images)} image(s)...")
|
||||
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
|
||||
print("✓ Forward pass completed!")
|
||||
print(f"✓ Number of embeddings: {len(doc_vecs)}")
|
||||
|
||||
if len(doc_vecs) > 0:
|
||||
emb = doc_vecs[0]
|
||||
print(f"✓ Embedding shape: {emb.shape}")
|
||||
print(f"✓ Embedding dtype: {emb.dtype}")
|
||||
print("✓ Embedding stats:")
|
||||
print(f" - Min: {emb.min().item():.4f}")
|
||||
print(f" - Max: {emb.max().item():.4f}")
|
||||
print(f" - Mean: {emb.mean().item():.4f}")
|
||||
print(f" - Std: {emb.std().item():.4f}")
|
||||
|
||||
# Check for NaN or Inf
|
||||
if torch.isnan(emb).any():
|
||||
print("⚠ Warning: Embedding contains NaN values!")
|
||||
if torch.isinf(emb).any():
|
||||
print("⚠ Warning: Embedding contains Inf values!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error during forward pass: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
BIN
apps/multimodal/vision-based-pdf-multi-vector/fig/image.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 166 KiB |
1452
apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
Normal file
1452
apps/multimodal/vision-based-pdf-multi-vector/leann_multi_vector.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,113 @@
|
||||
# pip install pdf2image
|
||||
# pip install pymilvus
|
||||
# pip install colpali_engine
|
||||
# pip install tqdm
|
||||
# pip install pillow
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Ensure local leann packages are importable before importing them
|
||||
_repo_root = Path(__file__).resolve().parents[3]
|
||||
_leann_core_src = _repo_root / "packages" / "leann-core" / "src"
|
||||
_leann_hnsw_pkg = _repo_root / "packages" / "leann-backend-hnsw"
|
||||
if str(_leann_core_src) not in sys.path:
|
||||
sys.path.insert(0, str(_leann_core_src))
|
||||
if str(_leann_hnsw_pkg) not in sys.path:
|
||||
sys.path.insert(0, str(_leann_hnsw_pkg))
|
||||
|
||||
from leann_multi_vector import LeannMultiVector
|
||||
|
||||
import torch
|
||||
from colpali_engine.models import ColPali
|
||||
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
||||
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Auto-select device: CUDA > MPS (mac) > CPU
|
||||
_device_str = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
"mps"
|
||||
if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
)
|
||||
device = get_torch_device(_device_str)
|
||||
# Prefer fp16 on GPU/MPS, bfloat16 on CPU
|
||||
_dtype = torch.float16 if _device_str in ("cuda", "mps") else torch.bfloat16
|
||||
model_name = "vidore/colpali-v1.2"
|
||||
|
||||
model = ColPali.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=_dtype,
|
||||
device_map=device,
|
||||
).eval()
|
||||
print(f"Using device={_device_str}, dtype={_dtype}")
|
||||
|
||||
queries = [
|
||||
"How to end-to-end retrieval with ColBert",
|
||||
"Where is ColBERT performance Table, including text representation results?",
|
||||
]
|
||||
|
||||
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](queries),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_queries(x),
|
||||
)
|
||||
|
||||
qs: list[torch.Tensor] = []
|
||||
for batch_query in dataloader:
|
||||
with torch.no_grad():
|
||||
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
|
||||
embeddings_query = model(**batch_query)
|
||||
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
|
||||
print(qs[0].shape)
|
||||
# %%
|
||||
page_filenames = sorted(os.listdir("./pages"), key=lambda n: int(re.search(r"\d+", n).group()))
|
||||
images = [Image.open(os.path.join("./pages", name)) for name in page_filenames]
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset=ListDataset[str](images),
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: processor.process_images(x),
|
||||
)
|
||||
|
||||
ds: list[torch.Tensor] = []
|
||||
for batch_doc in tqdm(dataloader):
|
||||
with torch.no_grad():
|
||||
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
|
||||
embeddings_doc = model(**batch_doc)
|
||||
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
|
||||
|
||||
print(ds[0].shape)
|
||||
|
||||
# %%
|
||||
# Build HNSW index via LeannMultiVector primitives and run search
|
||||
index_path = "./indexes/colpali.leann"
|
||||
retriever = LeannMultiVector(index_path=index_path, dim=int(ds[0].shape[-1]))
|
||||
retriever.create_collection()
|
||||
filepaths = [os.path.join("./pages", name) for name in page_filenames]
|
||||
for i in range(len(filepaths)):
|
||||
data = {
|
||||
"colbert_vecs": ds[i].float().numpy(),
|
||||
"doc_id": i,
|
||||
"filepath": filepaths[i],
|
||||
}
|
||||
retriever.insert(data)
|
||||
retriever.create_index()
|
||||
for query in qs:
|
||||
query_np = query.float().numpy()
|
||||
result = retriever.search(query_np, topk=1)
|
||||
print(filepaths[result[0][1]])
|
||||
@@ -0,0 +1,713 @@
|
||||
## Jupyter-style notebook script
|
||||
# %%
|
||||
# uv pip install matplotlib qwen_vl_utils
|
||||
import argparse
|
||||
import faulthandler
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
# Enable faulthandler to get stack trace on segfault
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
from leann_multi_vector import ( # utility functions/classes
|
||||
_ensure_repo_paths_importable,
|
||||
_load_images_from_dir,
|
||||
_maybe_convert_pdf_to_images,
|
||||
_load_colvision,
|
||||
_embed_images,
|
||||
_embed_queries,
|
||||
_build_index,
|
||||
_load_retriever_if_index_exists,
|
||||
_generate_similarity_map,
|
||||
_build_fast_plaid_index,
|
||||
_load_fast_plaid_index_if_exists,
|
||||
_search_fast_plaid,
|
||||
_get_fast_plaid_image,
|
||||
_get_fast_plaid_metadata,
|
||||
QwenVL,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# %%
|
||||
# Config
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
QUERY = "The paper talk about the latent video generative model and data curation in the related work part?"
|
||||
MODEL: str = "colqwen2" # "colpali" or "colqwen2"
|
||||
|
||||
# Data source: set to True to use the Hugging Face dataset example (recommended)
|
||||
USE_HF_DATASET: bool = True
|
||||
# Single dataset name (used when DATASET_NAMES is None)
|
||||
DATASET_NAME: str = "weaviate/arXiv-AI-papers-multi-vector"
|
||||
# Multiple datasets to combine (if provided, DATASET_NAME is ignored)
|
||||
# Can be:
|
||||
# - List of strings: ["dataset1", "dataset2"]
|
||||
# - List of tuples: [("dataset1", "config1"), ("dataset2", None)] # None = no config needed
|
||||
# - Mixed: ["dataset1", ("dataset2", "config2")]
|
||||
#
|
||||
# Some potential datasets with images (may need IMAGE_FIELD_NAME adjustment):
|
||||
# - "weaviate/arXiv-AI-papers-multi-vector" (current, has "page_image" field)
|
||||
# - ("lmms-lab/DocVQA", "DocVQA") (has "image" field, document images, needs config)
|
||||
# - ("lmms-lab/DocVQA", "InfographicVQA") (has "image" field, infographic images)
|
||||
# - "pixparse/arxiv-papers" (if available, arXiv papers)
|
||||
# - "allenai/ai2d" (AI2D diagram dataset, has "image" field)
|
||||
# - "huggingface/document-images" (if available)
|
||||
# Note: Check dataset structure first - some may need IMAGE_FIELD_NAME specified
|
||||
# DATASET_NAMES: Optional[list[str | tuple[str, Optional[str]]]] = None
|
||||
DATASET_NAMES = [
|
||||
"weaviate/arXiv-AI-papers-multi-vector",
|
||||
# ("lmms-lab/DocVQA", "DocVQA"), # Specify config name for datasets with multiple configs
|
||||
]
|
||||
# Load multiple splits to get more data (e.g., ["train", "test", "validation"])
|
||||
# Set to None to try loading all available splits automatically
|
||||
DATASET_SPLITS: Optional[list[str]] = ["train", "test"] # None = auto-detect all splits
|
||||
# Image field name in the dataset (auto-detect if None)
|
||||
# Common names: "page_image", "image", "images", "img"
|
||||
IMAGE_FIELD_NAME: Optional[str] = None # None = auto-detect
|
||||
MAX_DOCS: Optional[int] = None # limit number of pages to index; None = all
|
||||
|
||||
# Local pages (used when USE_HF_DATASET == False)
|
||||
PDF: Optional[str] = None # e.g., "./pdfs/2004.12832v2.pdf"
|
||||
PAGES_DIR: str = "./pages"
|
||||
# Custom folder path (takes precedence over USE_HF_DATASET and PAGES_DIR)
|
||||
# If set, images will be loaded directly from this folder
|
||||
CUSTOM_FOLDER_PATH: Optional[str] = None # e.g., "/home/ubuntu/dr-tulu/agent/screenshots"
|
||||
# Whether to recursively search subdirectories when loading from custom folder
|
||||
CUSTOM_FOLDER_RECURSIVE: bool = False # Set to True to search subdirectories
|
||||
|
||||
# Index + retrieval settings
|
||||
# Use a different index path for larger dataset to avoid overwriting existing index
|
||||
INDEX_PATH: str = "./indexes/colvision_large.leann"
|
||||
# Fast-Plaid index settings (alternative to LEANN index)
|
||||
# These are now command-line arguments (see CLI overrides section)
|
||||
TOPK: int = 3
|
||||
FIRST_STAGE_K: int = 500
|
||||
REBUILD_INDEX: bool = False # Set to True to force rebuild even if index exists
|
||||
|
||||
# Artifacts
|
||||
SAVE_TOP_IMAGE: Optional[str] = "./figures/retrieved_page.png"
|
||||
SIMILARITY_MAP: bool = True
|
||||
SIM_TOKEN_IDX: int = 13 # -1 means auto-select the most salient token
|
||||
SIM_OUTPUT: str = "./figures/similarity_map.png"
|
||||
ANSWER: bool = True
|
||||
MAX_NEW_TOKENS: int = 1024
|
||||
|
||||
|
||||
# %%
|
||||
# CLI overrides
|
||||
parser = argparse.ArgumentParser(description="Multi-vector LEANN similarity map demo")
|
||||
parser.add_argument(
|
||||
"--search-method",
|
||||
type=str,
|
||||
choices=["ann", "exact", "exact-all"],
|
||||
default="ann",
|
||||
help="Which search method to use: 'ann' (fast ANN), 'exact' (ANN + exact rerank), or 'exact-all' (exact over all docs).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default=QUERY,
|
||||
help=f"Query string to search for. Default: '{QUERY}'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Set to True to use fast-plaid instead of LEANN. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default="./indexes/colvision_fastplaid",
|
||||
help="Path to the Fast-Plaid index. Default: './indexes/colvision_fastplaid'",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--topk",
|
||||
type=int,
|
||||
default=TOPK,
|
||||
help=f"Number of top results to retrieve. Default: {TOPK}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--custom-folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a custom folder containing images to search. Takes precedence over dataset loading. Default: None",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recursive",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Recursively search subdirectories when loading images from custom folder. Default: False",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Force rebuild the index even if it already exists. Default: False (reuse existing index if available)",
|
||||
)
|
||||
cli_args, _unknown = parser.parse_known_args()
|
||||
SEARCH_METHOD: str = cli_args.search_method
|
||||
QUERY = cli_args.query # Override QUERY with CLI argument if provided
|
||||
USE_FAST_PLAID: bool = cli_args.use_fast_plaid
|
||||
FAST_PLAID_INDEX_PATH: str = cli_args.fast_plaid_index_path
|
||||
TOPK: int = cli_args.topk # Override TOPK with CLI argument if provided
|
||||
CUSTOM_FOLDER_PATH = cli_args.custom_folder if cli_args.custom_folder else CUSTOM_FOLDER_PATH # Override with CLI argument if provided
|
||||
CUSTOM_FOLDER_RECURSIVE = cli_args.recursive if cli_args.recursive else CUSTOM_FOLDER_RECURSIVE # Override with CLI argument if provided
|
||||
REBUILD_INDEX = cli_args.rebuild_index # Override REBUILD_INDEX with CLI argument
|
||||
|
||||
# %%
|
||||
|
||||
# Step 1: Check if we can skip data loading (index already exists)
|
||||
retriever: Optional[Any] = None
|
||||
fast_plaid_index: Optional[Any] = None
|
||||
need_to_build_index = REBUILD_INDEX
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid index handling
|
||||
if not REBUILD_INDEX:
|
||||
try:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is not None:
|
||||
print(f"✓ Fast-Plaid index found at {FAST_PLAID_INDEX_PATH}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Fast-Plaid index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
except Exception as e:
|
||||
# If loading fails (e.g., memory error, corrupted index), rebuild
|
||||
print(f"Warning: Failed to load Fast-Plaid index: {e}")
|
||||
print("Will rebuild the index...")
|
||||
need_to_build_index = True
|
||||
fast_plaid_index = None
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild Fast-Plaid index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
# Original LEANN index handling
|
||||
if not REBUILD_INDEX:
|
||||
retriever = _load_retriever_if_index_exists(INDEX_PATH)
|
||||
if retriever is not None:
|
||||
print(f"✓ Index loaded from {INDEX_PATH}")
|
||||
print(f"✓ Images available at: {retriever._images_dir_path()}")
|
||||
need_to_build_index = False
|
||||
else:
|
||||
print(f"Index not found, will build new index")
|
||||
need_to_build_index = True
|
||||
else:
|
||||
print(f"REBUILD_INDEX=True, will rebuild index")
|
||||
need_to_build_index = True
|
||||
|
||||
# Step 2: Load data only if we need to build the index
|
||||
if need_to_build_index:
|
||||
print("Loading dataset...")
|
||||
# Check for custom folder path first (takes precedence)
|
||||
if CUSTOM_FOLDER_PATH:
|
||||
if not os.path.isdir(CUSTOM_FOLDER_PATH):
|
||||
raise RuntimeError(f"Custom folder path does not exist: {CUSTOM_FOLDER_PATH}")
|
||||
print(f"Loading images from custom folder: {CUSTOM_FOLDER_PATH}")
|
||||
if CUSTOM_FOLDER_RECURSIVE:
|
||||
print(" (recursive mode: searching subdirectories)")
|
||||
filepaths, images = _load_images_from_dir(CUSTOM_FOLDER_PATH, recursive=CUSTOM_FOLDER_RECURSIVE)
|
||||
print(f" Found {len(filepaths)} image files")
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {CUSTOM_FOLDER_PATH}. Ensure the folder contains image files (.png, .jpg, .jpeg, .webp)."
|
||||
)
|
||||
print(f" Successfully loaded {len(images)} images")
|
||||
# Use filenames as identifiers instead of full paths for cleaner metadata
|
||||
filepaths = [os.path.basename(fp) for fp in filepaths]
|
||||
elif USE_HF_DATASET:
|
||||
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
|
||||
|
||||
# Determine which datasets to load
|
||||
if DATASET_NAMES is not None:
|
||||
dataset_names_to_load = DATASET_NAMES
|
||||
print(f"Loading {len(dataset_names_to_load)} datasets: {dataset_names_to_load}")
|
||||
else:
|
||||
dataset_names_to_load = [DATASET_NAME]
|
||||
print(f"Loading single dataset: {DATASET_NAME}")
|
||||
|
||||
# Load and combine datasets
|
||||
all_datasets_to_concat = []
|
||||
|
||||
for dataset_entry in dataset_names_to_load:
|
||||
# Handle both string and tuple formats
|
||||
if isinstance(dataset_entry, tuple):
|
||||
dataset_name, config_name = dataset_entry
|
||||
else:
|
||||
dataset_name = dataset_entry
|
||||
config_name = None
|
||||
|
||||
print(f"\nProcessing dataset: {dataset_name}" + (f" (config: {config_name})" if config_name else ""))
|
||||
|
||||
# Load dataset to check available splits
|
||||
# If config_name is provided, use it; otherwise try without config
|
||||
try:
|
||||
if config_name:
|
||||
dataset_dict = load_dataset(dataset_name, config_name)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_name)
|
||||
except ValueError as e:
|
||||
if "Config name is missing" in str(e):
|
||||
# Try to get available configs and suggest
|
||||
from datasets import get_dataset_config_names
|
||||
try:
|
||||
available_configs = get_dataset_config_names(dataset_name)
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Available configs: {available_configs}. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"Dataset '{dataset_name}' requires a config name. "
|
||||
f"Please specify as: ('{dataset_name}', 'config_name')"
|
||||
) from e
|
||||
raise
|
||||
|
||||
# Determine which splits to load
|
||||
if DATASET_SPLITS is None:
|
||||
# Auto-detect: try to load all available splits
|
||||
available_splits = list(dataset_dict.keys())
|
||||
print(f" Auto-detected splits: {available_splits}")
|
||||
splits_to_load = available_splits
|
||||
else:
|
||||
splits_to_load = DATASET_SPLITS
|
||||
|
||||
# Load and concatenate multiple splits for this dataset
|
||||
datasets_to_concat: list[Dataset] = []
|
||||
for split in splits_to_load:
|
||||
if split not in dataset_dict:
|
||||
print(f" Warning: Split '{split}' not found in dataset. Available splits: {list(dataset_dict.keys())}")
|
||||
continue
|
||||
split_dataset = cast(Dataset, dataset_dict[split])
|
||||
print(f" Loaded split '{split}': {len(split_dataset)} pages")
|
||||
datasets_to_concat.append(split_dataset)
|
||||
|
||||
if not datasets_to_concat:
|
||||
print(f" Warning: No valid splits found for {dataset_name}. Skipping.")
|
||||
continue
|
||||
|
||||
# Concatenate splits for this dataset
|
||||
if len(datasets_to_concat) > 1:
|
||||
combined_dataset = concatenate_datasets(datasets_to_concat)
|
||||
print(f" Concatenated {len(datasets_to_concat)} splits into {len(combined_dataset)} pages")
|
||||
else:
|
||||
combined_dataset = datasets_to_concat[0]
|
||||
|
||||
all_datasets_to_concat.append(combined_dataset)
|
||||
|
||||
if not all_datasets_to_concat:
|
||||
raise RuntimeError("No valid datasets or splits found.")
|
||||
|
||||
# Concatenate all datasets
|
||||
if len(all_datasets_to_concat) > 1:
|
||||
dataset = concatenate_datasets(all_datasets_to_concat)
|
||||
print(f"\nConcatenated {len(all_datasets_to_concat)} datasets into {len(dataset)} total pages")
|
||||
else:
|
||||
dataset = all_datasets_to_concat[0]
|
||||
|
||||
# Apply MAX_DOCS limit if specified
|
||||
N = len(dataset) if MAX_DOCS is None else min(MAX_DOCS, len(dataset))
|
||||
if N < len(dataset):
|
||||
print(f"Limiting to {N} pages (from {len(dataset)} total)")
|
||||
dataset = dataset.select(range(N))
|
||||
|
||||
# Auto-detect image field name if not specified
|
||||
if IMAGE_FIELD_NAME is None:
|
||||
# Check multiple samples to find the most common image field
|
||||
# (useful when datasets are merged and may have different field names)
|
||||
possible_image_fields = ["page_image", "image", "images", "img", "page", "document_image"]
|
||||
field_counts = {}
|
||||
|
||||
# Check first few samples to find image fields
|
||||
num_samples_to_check = min(10, len(dataset))
|
||||
for sample_idx in range(num_samples_to_check):
|
||||
sample = dataset[sample_idx]
|
||||
for field in possible_image_fields:
|
||||
if field in sample and sample[field] is not None:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
field_counts[field] = field_counts.get(field, 0) + 1
|
||||
|
||||
# Choose the most common field, or first found if tied
|
||||
if field_counts:
|
||||
image_field = max(field_counts.items(), key=lambda x: x[1])[0]
|
||||
print(f"Auto-detected image field: '{image_field}' (found in {field_counts[image_field]}/{num_samples_to_check} samples)")
|
||||
else:
|
||||
# Fallback: check first sample only
|
||||
sample = dataset[0]
|
||||
image_field = None
|
||||
for field in possible_image_fields:
|
||||
if field in sample:
|
||||
value = sample[field]
|
||||
if isinstance(value, Image.Image) or (hasattr(value, 'size') and hasattr(value, 'mode')):
|
||||
image_field = field
|
||||
break
|
||||
if image_field is None:
|
||||
raise RuntimeError(
|
||||
f"Could not auto-detect image field. Available fields: {list(sample.keys())}. "
|
||||
f"Please specify IMAGE_FIELD_NAME manually."
|
||||
)
|
||||
print(f"Auto-detected image field: '{image_field}'")
|
||||
else:
|
||||
image_field = IMAGE_FIELD_NAME
|
||||
if image_field not in dataset[0]:
|
||||
raise RuntimeError(
|
||||
f"Image field '{image_field}' not found. Available fields: {list(dataset[0].keys())}"
|
||||
)
|
||||
|
||||
filepaths: list[str] = []
|
||||
images: list[Image.Image] = []
|
||||
for i in tqdm(range(len(dataset)), desc="Loading dataset", total=len(dataset)):
|
||||
p = dataset[i]
|
||||
# Try to compose a descriptive identifier
|
||||
# Handle different dataset structures
|
||||
identifier_parts = []
|
||||
|
||||
# Helper function to safely get field value
|
||||
def safe_get(field_name, default=None):
|
||||
if field_name in p and p[field_name] is not None:
|
||||
return p[field_name]
|
||||
return default
|
||||
|
||||
# Try to get various identifier fields
|
||||
if safe_get("paper_arxiv_id"):
|
||||
identifier_parts.append(f"arXiv:{p['paper_arxiv_id']}")
|
||||
if safe_get("paper_title"):
|
||||
identifier_parts.append(f"title:{p['paper_title']}")
|
||||
if safe_get("page_number") is not None:
|
||||
try:
|
||||
identifier_parts.append(f"page:{int(p['page_number'])}")
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use the raw value or skip
|
||||
if p['page_number']:
|
||||
identifier_parts.append(f"page:{p['page_number']}")
|
||||
if safe_get("page_id"):
|
||||
identifier_parts.append(f"id:{p['page_id']}")
|
||||
elif safe_get("questionId"):
|
||||
identifier_parts.append(f"qid:{p['questionId']}")
|
||||
elif safe_get("docId"):
|
||||
identifier_parts.append(f"docId:{p['docId']}")
|
||||
elif safe_get("id"):
|
||||
identifier_parts.append(f"id:{p['id']}")
|
||||
|
||||
# If no identifier parts found, create one from index
|
||||
if identifier_parts:
|
||||
identifier = "|".join(identifier_parts)
|
||||
else:
|
||||
# Create identifier from available fields or index
|
||||
fallback_parts = []
|
||||
# Try common fields that might exist
|
||||
for field in ["ucsf_document_id", "docId", "questionId", "id"]:
|
||||
if safe_get(field):
|
||||
fallback_parts.append(f"{field}:{p[field]}")
|
||||
break
|
||||
if fallback_parts:
|
||||
identifier = "|".join(fallback_parts) + f"|idx:{i}"
|
||||
else:
|
||||
identifier = f"doc_{i}"
|
||||
|
||||
filepaths.append(identifier)
|
||||
|
||||
# Get image - try detected field first, then fallback to other common fields
|
||||
img = None
|
||||
if image_field in p and p[image_field] is not None:
|
||||
img = p[image_field]
|
||||
else:
|
||||
# Fallback: try other common image field names
|
||||
for fallback_field in ["image", "page_image", "images", "img"]:
|
||||
if fallback_field in p and p[fallback_field] is not None:
|
||||
img = p[fallback_field]
|
||||
break
|
||||
|
||||
if img is None:
|
||||
raise RuntimeError(
|
||||
f"No image found for sample {i}. Available fields: {list(p.keys())}. "
|
||||
f"Expected field: {image_field}"
|
||||
)
|
||||
|
||||
# Ensure it's a PIL Image
|
||||
if not isinstance(img, Image.Image):
|
||||
if hasattr(img, 'convert'):
|
||||
img = img.convert('RGB')
|
||||
else:
|
||||
img = Image.fromarray(img) if hasattr(img, '__array__') else Image.open(img)
|
||||
images.append(img)
|
||||
else:
|
||||
_maybe_convert_pdf_to_images(PDF, PAGES_DIR)
|
||||
filepaths, images = _load_images_from_dir(PAGES_DIR)
|
||||
if not images:
|
||||
raise RuntimeError(
|
||||
f"No images found in {PAGES_DIR}. Provide PDF path in PDF variable or ensure images exist."
|
||||
)
|
||||
print(f"Loaded {len(images)} images")
|
||||
|
||||
# Memory check before loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f"Memory usage after loading images: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
print("Skipping dataset loading (using existing index)")
|
||||
filepaths = [] # Not needed when using existing index
|
||||
images = [] # Not needed when using existing index
|
||||
|
||||
|
||||
# %%
|
||||
# Step 3: Load model and processor (only if we need to build index or perform search)
|
||||
print("Step 3: Loading model and processor...")
|
||||
print(f" Model: {MODEL}")
|
||||
try:
|
||||
import sys
|
||||
print(f" Python version: {sys.version}")
|
||||
print(f" Python executable: {sys.executable}")
|
||||
|
||||
model_name, model, processor, device_str, device, dtype = _load_colvision(MODEL)
|
||||
print(f"✓ Using model={model_name}, device={device_str}, dtype={dtype}")
|
||||
|
||||
# Memory check after loading model
|
||||
try:
|
||||
import psutil
|
||||
import torch
|
||||
process = psutil.Process(os.getpid())
|
||||
mem_info = process.memory_info()
|
||||
print(f" Memory usage after loading model: {mem_info.rss / 1024 / 1024 / 1024:.2f} GB")
|
||||
if torch.cuda.is_available():
|
||||
print(f" GPU memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
|
||||
print(f" GPU memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading model: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
# %%
|
||||
|
||||
# %%
|
||||
# Step 4: Build index if needed
|
||||
if need_to_build_index:
|
||||
print("Step 4: Building index...")
|
||||
print(f" Number of images: {len(images)}")
|
||||
print(f" Number of filepaths: {len(filepaths)}")
|
||||
|
||||
try:
|
||||
print(" Embedding images...")
|
||||
doc_vecs = _embed_images(model, processor, images)
|
||||
print(f" Embedded {len(doc_vecs)} documents")
|
||||
print(f" First doc vec shape: {doc_vecs[0].shape if len(doc_vecs) > 0 else 'N/A'}")
|
||||
except Exception as e:
|
||||
print(f"Error embedding images: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
if USE_FAST_PLAID:
|
||||
# Build Fast-Plaid index
|
||||
print(" Building Fast-Plaid index...")
|
||||
try:
|
||||
fast_plaid_index, build_secs = _build_fast_plaid_index(
|
||||
FAST_PLAID_INDEX_PATH, doc_vecs, filepaths, images
|
||||
)
|
||||
from pathlib import Path
|
||||
print(f"✓ Fast-Plaid index built in {build_secs:.3f}s")
|
||||
print(f"✓ Index saved to: {FAST_PLAID_INDEX_PATH}")
|
||||
print(f"✓ Images saved to: {Path(FAST_PLAID_INDEX_PATH) / 'images'}")
|
||||
except Exception as e:
|
||||
print(f"Error building Fast-Plaid index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
else:
|
||||
# Build original LEANN index
|
||||
try:
|
||||
retriever = _build_index(INDEX_PATH, doc_vecs, filepaths, images)
|
||||
print(f"✓ Index built and images saved to: {retriever._images_dir_path()}")
|
||||
except Exception as e:
|
||||
print(f"Error building LEANN index: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clear memory
|
||||
print(" Clearing memory...")
|
||||
del images, filepaths, doc_vecs
|
||||
|
||||
# Note: Images are now stored separately, retriever/fast_plaid_index will reference them
|
||||
|
||||
|
||||
# %%
|
||||
# Step 5: Embed query and search
|
||||
_t0 = time.perf_counter()
|
||||
q_vec = _embed_queries(model, processor, [QUERY])[0]
|
||||
query_embed_secs = time.perf_counter() - _t0
|
||||
|
||||
print(f"[Search] Method: {SEARCH_METHOD}")
|
||||
print(f"[Timing] Query embedding: {query_embed_secs:.3f}s")
|
||||
|
||||
# Run the selected search method and time it
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid search
|
||||
if fast_plaid_index is None:
|
||||
fast_plaid_index = _load_fast_plaid_index_if_exists(FAST_PLAID_INDEX_PATH)
|
||||
if fast_plaid_index is None:
|
||||
raise RuntimeError(f"Fast-Plaid index not found at {FAST_PLAID_INDEX_PATH}")
|
||||
|
||||
results, search_secs = _search_fast_plaid(fast_plaid_index, q_vec, TOPK)
|
||||
print(f"[Timing] Fast-Plaid Search: {search_secs:.3f}s")
|
||||
else:
|
||||
# Original LEANN search
|
||||
query_np = q_vec.float().numpy()
|
||||
|
||||
if SEARCH_METHOD == "ann":
|
||||
results = retriever.search(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (ANN): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact":
|
||||
results = retriever.search_exact(query_np, topk=TOPK, first_stage_k=FIRST_STAGE_K)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact rerank): {search_secs:.3f}s (first_stage_k={FIRST_STAGE_K})")
|
||||
elif SEARCH_METHOD == "exact-all":
|
||||
results = retriever.search_exact_all(query_np, topk=TOPK)
|
||||
search_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Search (Exact all): {search_secs:.3f}s")
|
||||
else:
|
||||
results = []
|
||||
if not results:
|
||||
print("No results found.")
|
||||
else:
|
||||
print(f'Top {len(results)} results for query: "{QUERY}"')
|
||||
print("\n[DEBUG] Retrieval details:")
|
||||
top_images: list[Image.Image] = []
|
||||
image_hashes = {} # Track image hashes to detect duplicates
|
||||
|
||||
for rank, (score, doc_id) in enumerate(results, start=1):
|
||||
# Retrieve image and metadata based on index type
|
||||
if USE_FAST_PLAID:
|
||||
# Fast-Plaid: load image and get metadata
|
||||
image = _get_fast_plaid_image(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not find image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = _get_fast_plaid_metadata(FAST_PLAID_INDEX_PATH, doc_id)
|
||||
path = metadata.get("filepath", f"doc_{doc_id}") if metadata else f"doc_{doc_id}"
|
||||
top_images.append(image)
|
||||
else:
|
||||
# Original LEANN: retrieve from retriever
|
||||
image = retriever.get_image(doc_id)
|
||||
if image is None:
|
||||
print(f"Warning: Could not retrieve image for doc_id {doc_id}")
|
||||
continue
|
||||
|
||||
metadata = retriever.get_metadata(doc_id)
|
||||
path = metadata.get("filepath", "unknown") if metadata else "unknown"
|
||||
top_images.append(image)
|
||||
|
||||
# Calculate image hash to detect duplicates
|
||||
import hashlib
|
||||
import io
|
||||
# Convert image to bytes for hashing
|
||||
img_bytes = io.BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
image_bytes = img_bytes.getvalue()
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()[:8]
|
||||
|
||||
# Check if this image was already seen
|
||||
duplicate_info = ""
|
||||
if image_hash in image_hashes:
|
||||
duplicate_info = f" [DUPLICATE of rank {image_hashes[image_hash]}]"
|
||||
else:
|
||||
image_hashes[image_hash] = rank
|
||||
|
||||
# Print detailed information
|
||||
print(f"{rank}) doc_id={doc_id}, MaxSim={score:.4f}, Page={path}, ImageHash={image_hash}{duplicate_info}")
|
||||
if metadata:
|
||||
print(f" Metadata: {metadata}")
|
||||
|
||||
if SAVE_TOP_IMAGE:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
base = _Path(SAVE_TOP_IMAGE)
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if base.suffix:
|
||||
out_path = base.parent / f"{base.stem}_rank{rank}{base.suffix}"
|
||||
else:
|
||||
out_path = base / f"retrieved_page_rank{rank}.png"
|
||||
img.save(str(out_path))
|
||||
# Print the retrieval score (document-level MaxSim) alongside the saved path
|
||||
try:
|
||||
score, _doc_id = results[rank - 1]
|
||||
print(f"Saved retrieved page (rank {rank}) [MaxSim={score:.4f}] to: {out_path}")
|
||||
except Exception:
|
||||
print(f"Saved retrieved page (rank {rank}) to: {out_path}")
|
||||
|
||||
|
||||
# %%
|
||||
# Step 6: Similarity maps for top-K results
|
||||
if results and SIMILARITY_MAP:
|
||||
token_idx = None if SIM_TOKEN_IDX < 0 else int(SIM_TOKEN_IDX)
|
||||
from pathlib import Path as _Path
|
||||
|
||||
output_base = _Path(SIM_OUTPUT) if SIM_OUTPUT else None
|
||||
for rank, img in enumerate(top_images[:TOPK], start=1):
|
||||
if output_base:
|
||||
if output_base.suffix:
|
||||
out_dir = output_base.parent
|
||||
out_name = f"{output_base.stem}_rank{rank}{output_base.suffix}"
|
||||
out_path = str(out_dir / out_name)
|
||||
else:
|
||||
out_dir = output_base
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_path = str(out_dir / f"similarity_map_rank{rank}.png")
|
||||
else:
|
||||
out_path = None
|
||||
chosen_idx, max_sim = _generate_similarity_map(
|
||||
model=model,
|
||||
processor=processor,
|
||||
image=img,
|
||||
query=QUERY,
|
||||
token_idx=token_idx,
|
||||
output_path=out_path,
|
||||
)
|
||||
if out_path:
|
||||
print(
|
||||
f"Saved similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f}) to: {out_path}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Computed similarity map for rank {rank}, token #{chosen_idx} (max={max_sim:.2f})"
|
||||
)
|
||||
|
||||
|
||||
# %%
|
||||
# Step 7: Optional answer generation
|
||||
if results and ANSWER:
|
||||
qwen = QwenVL(device=device_str)
|
||||
_t0 = time.perf_counter()
|
||||
response = qwen.answer(QUERY, top_images[:TOPK], max_new_tokens=MAX_NEW_TOKENS)
|
||||
gen_secs = time.perf_counter() - _t0
|
||||
print(f"[Timing] Generation: {gen_secs:.3f}s")
|
||||
print("\nAnswer:")
|
||||
print(response)
|
||||
@@ -0,0 +1,451 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v1 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v1 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v1 tasks
|
||||
python vidore_v1_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v1_benchmark.py --model colqwen2 --task VidoreArxivQARetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v1_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# ViDoRe v1 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V1_TASKS = {
|
||||
"VidoreArxivQARetrieval": {
|
||||
"dataset_path": "vidore/arxivqa_test_subsampled_beir",
|
||||
"revision": "7d94d570960eac2408d3baa7a33f9de4822ae3e4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreDocVQARetrieval": {
|
||||
"dataset_path": "vidore/docvqa_test_subsampled_beir",
|
||||
"revision": "162ba2fc1a8437eda8b6c37b240bc1c0f0deb092",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreInfoVQARetrieval": {
|
||||
"dataset_path": "vidore/infovqa_test_subsampled_beir",
|
||||
"revision": "b802cc5fd6c605df2d673a963667d74881d2c9a4",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTabfquadRetrieval": {
|
||||
"dataset_path": "vidore/tabfquad_test_subsampled_beir",
|
||||
"revision": "61a2224bcd29b7b261a4892ff4c8bea353527a31",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreTatdqaRetrieval": {
|
||||
"dataset_path": "vidore/tatdqa_test_beir",
|
||||
"revision": "5feb5630fdff4d8d189ffedb2dba56862fdd45c0",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreShiftProjectRetrieval": {
|
||||
"dataset_path": "vidore/shiftproject_test_beir",
|
||||
"revision": "84a382e05c4473fed9cff2bbae95fe2379416117",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAAIRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_artificial_intelligence_test_beir",
|
||||
"revision": "2d9ebea5a1c6e9ef4a3b902a612f605dca11261c",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAEnergyRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_energy_test_beir",
|
||||
"revision": "9935aadbad5c8deec30910489db1b2c7133ae7a7",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAGovernmentReportsRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_government_reports_test_beir",
|
||||
"revision": "b4909afa930f81282fd20601e860668073ad02aa",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"VidoreSyntheticDocQAHealthcareIndustryRetrieval": {
|
||||
"dataset_path": "vidore/syntheticDocQA_healthcare_industry_test_beir",
|
||||
"revision": "f9e25d5b6e13e1ad9f5c3cce202565031b3ab164",
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
# Task name aliases (short names -> full names)
|
||||
TASK_ALIASES = {
|
||||
"arxivqa": "VidoreArxivQARetrieval",
|
||||
"docvqa": "VidoreDocVQARetrieval",
|
||||
"infovqa": "VidoreInfoVQARetrieval",
|
||||
"tabfquad": "VidoreTabfquadRetrieval",
|
||||
"tatdqa": "VidoreTatdqaRetrieval",
|
||||
"shiftproject": "VidoreShiftProjectRetrieval",
|
||||
"syntheticdocqa_ai": "VidoreSyntheticDocQAAIRetrieval",
|
||||
"syntheticdocqa_energy": "VidoreSyntheticDocQAEnergyRetrieval",
|
||||
"syntheticdocqa_government": "VidoreSyntheticDocQAGovernmentReportsRetrieval",
|
||||
"syntheticdocqa_healthcare": "VidoreSyntheticDocQAHealthcareIndustryRetrieval",
|
||||
}
|
||||
|
||||
|
||||
def normalize_task_name(task_name: str) -> str:
|
||||
"""Normalize task name (handle aliases)."""
|
||||
task_name_lower = task_name.lower()
|
||||
if task_name in VIDORE_V1_TASKS:
|
||||
return task_name
|
||||
if task_name_lower in TASK_ALIASES:
|
||||
return TASK_ALIASES[task_name_lower]
|
||||
# Try partial match
|
||||
for alias, full_name in TASK_ALIASES.items():
|
||||
if alias in task_name_lower or task_name_lower in alias:
|
||||
return full_name
|
||||
return task_name
|
||||
|
||||
|
||||
def get_safe_model_name(model_name: str) -> str:
|
||||
"""Get a safe model name for use in file paths."""
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
# If it's a path, use basename or hash
|
||||
if os.path.exists(model_name) and os.path.isdir(model_name):
|
||||
# Use basename if it's reasonable, otherwise use hash
|
||||
basename = os.path.basename(model_name.rstrip("/"))
|
||||
if basename and len(basename) < 100 and not basename.startswith("."):
|
||||
return basename
|
||||
# Use hash for very long or problematic paths
|
||||
return hashlib.md5(model_name.encode()).hexdigest()[:16]
|
||||
# For HuggingFace model names, replace / with _
|
||||
return model_name.replace("/", "_").replace(":", "_")
|
||||
|
||||
|
||||
def load_vidore_v1_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v1 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split})")
|
||||
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 1000,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v1 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Normalize task name (handle aliases)
|
||||
task_name = normalize_task_name(task_name)
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V1_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V1_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V1_TASKS[task_name]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v1_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
)
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 20, 100, 1000]
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(f"\nWarning: No queries found for task {task_name}. Skipping evaluation.")
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
# Use safe model name for index path (different models need different indexes)
|
||||
safe_model_name = get_safe_model_name(model_name)
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{safe_model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v1 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
help="Model to use: 'colqwen2', 'colpali', or path to a model directory (supports LoRA adapters)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Top-k results to retrieve (MTEB default is max(k_values)=1000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,20,100,1000",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v1_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [normalize_task_name(args.task)]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V1_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [normalize_task_name(t.strip()) for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,443 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Modular script to reproduce NDCG results for ViDoRe v2 benchmark.
|
||||
|
||||
This script uses the interface from leann_multi_vector.py to:
|
||||
1. Download ViDoRe v2 datasets
|
||||
2. Build indexes (LEANN or Fast-Plaid)
|
||||
3. Perform retrieval
|
||||
4. Evaluate using NDCG metrics
|
||||
|
||||
Usage:
|
||||
# Evaluate all ViDoRe v2 tasks
|
||||
python vidore_v2_benchmark.py --model colqwen2 --tasks all
|
||||
|
||||
# Evaluate specific task
|
||||
python vidore_v2_benchmark.py --model colqwen2 --task Vidore2ESGReportsRetrieval
|
||||
|
||||
# Use Fast-Plaid index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --use-fast-plaid --fast-plaid-index-path ./indexes/vidore_fastplaid
|
||||
|
||||
# Rebuild index
|
||||
python vidore_v2_benchmark.py --model colqwen2 --rebuild-index
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from leann_multi_vector import (
|
||||
ViDoReBenchmarkEvaluator,
|
||||
_ensure_repo_paths_importable,
|
||||
)
|
||||
|
||||
_ensure_repo_paths_importable(__file__)
|
||||
|
||||
# Language name to dataset language field value mapping
|
||||
# Dataset uses ISO 639-3 + ISO 15924 format (e.g., "eng-Latn")
|
||||
LANGUAGE_MAPPING = {
|
||||
"english": "eng-Latn",
|
||||
"french": "fra-Latn",
|
||||
"spanish": "spa-Latn",
|
||||
"german": "deu-Latn",
|
||||
}
|
||||
|
||||
# ViDoRe v2 task configurations
|
||||
# Prompts match MTEB task metadata prompts
|
||||
VIDORE_V2_TASKS = {
|
||||
"Vidore2ESGReportsRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_v2",
|
||||
"revision": "0542c0d03da0ec1c8cbc517c8d78e7e95c75d3d3",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2EconomicsReportsRetrieval": {
|
||||
"dataset_path": "vidore/economics_reports_v2",
|
||||
"revision": "b3e3a04b07fbbaffe79be49dabf92f691fbca252",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2BioMedicalLecturesRetrieval": {
|
||||
"dataset_path": "vidore/biomedical_lectures_v2",
|
||||
"revision": "a29202f0da409034d651614d87cd8938d254e2ea",
|
||||
"languages": ["french", "spanish", "english", "german"],
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
"Vidore2ESGReportsHLRetrieval": {
|
||||
"dataset_path": "vidore/esg_reports_human_labeled_v2",
|
||||
"revision": "6d467dedb09a75144ede1421747e47cf036857dd",
|
||||
# Note: This dataset doesn't have language filtering - all queries are English
|
||||
"languages": None, # No language filtering needed
|
||||
"prompt": {"query": "Find a screenshot that relevant to the user's question."},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_vidore_v2_data(
|
||||
dataset_path: str,
|
||||
revision: Optional[str] = None,
|
||||
split: str = "test",
|
||||
language: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load ViDoRe v2 dataset.
|
||||
|
||||
Returns:
|
||||
corpus: dict mapping corpus_id to PIL Image
|
||||
queries: dict mapping query_id to query text
|
||||
qrels: dict mapping query_id to dict of {corpus_id: relevance_score}
|
||||
"""
|
||||
print(f"Loading dataset: {dataset_path} (split={split}, language={language})")
|
||||
|
||||
# Load queries - cast to Dataset since we know split returns Dataset not DatasetDict
|
||||
query_ds = cast(Dataset, load_dataset(dataset_path, "queries", split=split, revision=revision))
|
||||
|
||||
# Check if dataset has language field before filtering
|
||||
has_language_field = len(query_ds) > 0 and "language" in query_ds.column_names
|
||||
|
||||
if language and has_language_field:
|
||||
# Map language name to dataset language field value (e.g., "english" -> "eng-Latn")
|
||||
dataset_language = LANGUAGE_MAPPING.get(language, language)
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == dataset_language)
|
||||
# Check if filtering resulted in empty dataset
|
||||
if len(query_ds_filtered) == 0:
|
||||
print(
|
||||
f"Warning: No queries found after filtering by language '{language}' (mapped to '{dataset_language}')."
|
||||
)
|
||||
# Try with original language value (dataset might use simple names like 'english')
|
||||
print(f"Trying with original language value '{language}'...")
|
||||
query_ds_filtered = query_ds.filter(lambda x: x.get("language") == language)
|
||||
if len(query_ds_filtered) == 0:
|
||||
# Try to get a sample to see actual language values
|
||||
try:
|
||||
sample_ds = cast(
|
||||
Dataset,
|
||||
load_dataset(dataset_path, "queries", split=split, revision=revision),
|
||||
)
|
||||
if len(sample_ds) > 0 and "language" in sample_ds.column_names:
|
||||
sample_langs = set(sample_ds["language"])
|
||||
print(f"Available language values in dataset: {sample_langs}")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Found {len(query_ds_filtered)} queries using original language value '{language}'"
|
||||
)
|
||||
query_ds = query_ds_filtered
|
||||
|
||||
queries: dict[str, str] = {}
|
||||
for row in query_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
queries[query_id] = row_dict["query"]
|
||||
|
||||
# Load corpus (images) - cast to Dataset
|
||||
corpus_ds = cast(Dataset, load_dataset(dataset_path, "corpus", split=split, revision=revision))
|
||||
|
||||
corpus: dict[str, Any] = {}
|
||||
for row in corpus_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
# Extract image from the dataset row
|
||||
if "image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["image"]
|
||||
elif "page_image" in row_dict:
|
||||
corpus[corpus_id] = row_dict["page_image"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No image field found in corpus. Available fields: {list(row_dict.keys())}"
|
||||
)
|
||||
|
||||
# Load qrels (relevance judgments) - cast to Dataset
|
||||
qrels_ds = cast(Dataset, load_dataset(dataset_path, "qrels", split=split, revision=revision))
|
||||
|
||||
qrels: dict[str, dict[str, int]] = {}
|
||||
for row in qrels_ds:
|
||||
row_dict = cast(dict[str, Any], row)
|
||||
query_id = f"query-{split}-{row_dict['query-id']}"
|
||||
corpus_id = f"corpus-{split}-{row_dict['corpus-id']}"
|
||||
if query_id not in qrels:
|
||||
qrels[query_id] = {}
|
||||
qrels[query_id][corpus_id] = int(row_dict["score"])
|
||||
|
||||
print(
|
||||
f"Loaded {len(queries)} queries, {len(corpus)} corpus items, {len(qrels)} query-relevance mappings"
|
||||
)
|
||||
|
||||
# Filter qrels to only include queries that exist
|
||||
qrels = {qid: rel_docs for qid, rel_docs in qrels.items() if qid in queries}
|
||||
|
||||
# Filter out queries without any relevant documents (matching MTEB behavior)
|
||||
# This is important for correct NDCG calculation
|
||||
qrels_filtered = {qid: rel_docs for qid, rel_docs in qrels.items() if len(rel_docs) > 0}
|
||||
queries_filtered = {
|
||||
qid: query_text for qid, query_text in queries.items() if qid in qrels_filtered
|
||||
}
|
||||
|
||||
print(
|
||||
f"After filtering queries without positives: {len(queries_filtered)} queries, {len(qrels_filtered)} query-relevance mappings"
|
||||
)
|
||||
|
||||
return corpus, queries_filtered, qrels_filtered
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
task_name: str,
|
||||
model_name: str,
|
||||
index_path: str,
|
||||
use_fast_plaid: bool = False,
|
||||
fast_plaid_index_path: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
rebuild_index: bool = False,
|
||||
top_k: int = 100,
|
||||
first_stage_k: int = 500,
|
||||
k_values: Optional[list[int]] = None,
|
||||
output_dir: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Evaluate a single ViDoRe v2 task.
|
||||
"""
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Evaluating task: {task_name}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Get task config
|
||||
if task_name not in VIDORE_V2_TASKS:
|
||||
raise ValueError(f"Unknown task: {task_name}. Available: {list(VIDORE_V2_TASKS.keys())}")
|
||||
|
||||
task_config = VIDORE_V2_TASKS[task_name]
|
||||
dataset_path = str(task_config["dataset_path"])
|
||||
revision = str(task_config["revision"])
|
||||
|
||||
# Determine language
|
||||
if language is None:
|
||||
# Use first language if multiple available
|
||||
languages = cast(Optional[list[str]], task_config.get("languages"))
|
||||
if languages is None:
|
||||
# Task doesn't support language filtering (e.g., Vidore2ESGReportsHLRetrieval)
|
||||
language = None
|
||||
elif len(languages) == 1:
|
||||
language = languages[0]
|
||||
else:
|
||||
language = None
|
||||
|
||||
# Initialize k_values if not provided
|
||||
if k_values is None:
|
||||
k_values = [1, 3, 5, 10, 100]
|
||||
|
||||
# Load data
|
||||
corpus, queries, qrels = load_vidore_v2_data(
|
||||
dataset_path=dataset_path,
|
||||
revision=revision,
|
||||
split="test",
|
||||
language=language,
|
||||
)
|
||||
|
||||
# Check if we have any queries
|
||||
if len(queries) == 0:
|
||||
print(
|
||||
f"\nWarning: No queries found for task {task_name} with language {language}. Skipping evaluation."
|
||||
)
|
||||
# Return zero scores
|
||||
scores = {}
|
||||
for k in k_values:
|
||||
scores[f"ndcg_at_{k}"] = 0.0
|
||||
scores[f"map_at_{k}"] = 0.0
|
||||
scores[f"recall_at_{k}"] = 0.0
|
||||
scores[f"precision_at_{k}"] = 0.0
|
||||
scores[f"mrr_at_{k}"] = 0.0
|
||||
return scores
|
||||
|
||||
# Initialize evaluator
|
||||
evaluator = ViDoReBenchmarkEvaluator(
|
||||
model_name=model_name,
|
||||
use_fast_plaid=use_fast_plaid,
|
||||
top_k=top_k,
|
||||
first_stage_k=first_stage_k,
|
||||
k_values=k_values,
|
||||
)
|
||||
|
||||
# Build or load index
|
||||
index_path_full = index_path if not use_fast_plaid else fast_plaid_index_path
|
||||
if index_path_full is None:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}"
|
||||
if use_fast_plaid:
|
||||
index_path_full = f"./indexes/{task_name}_{model_name}_fastplaid"
|
||||
|
||||
index_or_retriever, corpus_ids_ordered = evaluator.build_index_from_corpus(
|
||||
corpus=corpus,
|
||||
index_path=index_path_full,
|
||||
rebuild=rebuild_index,
|
||||
)
|
||||
|
||||
# Search queries
|
||||
task_prompt = cast(Optional[dict[str, str]], task_config.get("prompt"))
|
||||
results = evaluator.search_queries(
|
||||
queries=queries,
|
||||
corpus_ids=corpus_ids_ordered,
|
||||
index_or_retriever=index_or_retriever,
|
||||
fast_plaid_index_path=fast_plaid_index_path,
|
||||
task_prompt=task_prompt,
|
||||
)
|
||||
|
||||
# Evaluate
|
||||
scores = evaluator.evaluate_results(results, qrels, k_values=k_values)
|
||||
|
||||
# Print results
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Results for {task_name}:")
|
||||
print(f"{'=' * 80}")
|
||||
for metric, value in scores.items():
|
||||
if isinstance(value, (int, float)):
|
||||
print(f" {metric}: {value:.5f}")
|
||||
|
||||
# Save results
|
||||
if output_dir:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
results_file = os.path.join(output_dir, f"{task_name}_results.json")
|
||||
scores_file = os.path.join(output_dir, f"{task_name}_scores.json")
|
||||
|
||||
with open(results_file, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved results to: {results_file}")
|
||||
|
||||
with open(scores_file, "w") as f:
|
||||
json.dump(scores, f, indent=2)
|
||||
print(f"Saved scores to: {scores_file}")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate ViDoRe v2 benchmark using LEANN/Fast-Plaid indexing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="colqwen2",
|
||||
choices=["colqwen2", "colpali"],
|
||||
help="Model to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific task to evaluate (or 'all' for all tasks)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="all",
|
||||
help="Tasks to evaluate: 'all' or comma-separated list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to LEANN index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-fast-plaid",
|
||||
action="store_true",
|
||||
help="Use Fast-Plaid instead of LEANN",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-plaid-index-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Fast-Plaid index (auto-generated if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rebuild-index",
|
||||
action="store_true",
|
||||
help="Rebuild index even if it exists",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Language to evaluate (default: first available)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Top-k results to retrieve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--first-stage-k",
|
||||
type=int,
|
||||
default=500,
|
||||
help="First stage k for LEANN search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--k-values",
|
||||
type=str,
|
||||
default="1,3,5,10,100",
|
||||
help="Comma-separated k values for evaluation (e.g., '1,3,5,10,100')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="./vidore_v2_results",
|
||||
help="Output directory for results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse k_values
|
||||
k_values = [int(k.strip()) for k in args.k_values.split(",")]
|
||||
|
||||
# Determine tasks to evaluate
|
||||
if args.task:
|
||||
tasks_to_eval = [args.task]
|
||||
elif args.tasks.lower() == "all":
|
||||
tasks_to_eval = list(VIDORE_V2_TASKS.keys())
|
||||
else:
|
||||
tasks_to_eval = [t.strip() for t in args.tasks.split(",")]
|
||||
|
||||
print(f"Tasks to evaluate: {tasks_to_eval}")
|
||||
|
||||
# Evaluate each task
|
||||
all_scores = {}
|
||||
for task_name in tasks_to_eval:
|
||||
try:
|
||||
scores = evaluate_task(
|
||||
task_name=task_name,
|
||||
model_name=args.model,
|
||||
index_path=args.index_path,
|
||||
use_fast_plaid=args.use_fast_plaid,
|
||||
fast_plaid_index_path=args.fast_plaid_index_path,
|
||||
language=args.language,
|
||||
rebuild_index=args.rebuild_index,
|
||||
top_k=args.top_k,
|
||||
first_stage_k=args.first_stage_k,
|
||||
k_values=k_values,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
all_scores[task_name] = scores
|
||||
except Exception as e:
|
||||
print(f"\nError evaluating {task_name}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
# Print summary
|
||||
if all_scores:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
for task_name, scores in all_scores.items():
|
||||
print(f"\n{task_name}:")
|
||||
# Print main metrics
|
||||
for metric in ["ndcg_at_5", "ndcg_at_10", "ndcg_at_100", "map_at_10", "recall_at_10"]:
|
||||
if metric in scores:
|
||||
print(f" {metric}: {scores[metric]:.5f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
183
apps/semantic_file_search/leann-plus-temporal-search.py
Normal file
183
apps/semantic_file_search/leann-plus-temporal-search.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
|
||||
|
||||
class TimeParser:
|
||||
def __init__(self):
|
||||
# Main pattern: captures optional fuzzy modifier, number, unit, and optional "ago"
|
||||
self.pattern = r"(?:(around|about|roughly|approximately)\s+)?(\d+)\s+(hour|day|week|month|year)s?(?:\s+ago)?"
|
||||
|
||||
# Compile for performance
|
||||
self.regex = re.compile(self.pattern, re.IGNORECASE)
|
||||
|
||||
# Stop words to remove before regex parsing
|
||||
self.stop_words = {
|
||||
"in",
|
||||
"at",
|
||||
"of",
|
||||
"by",
|
||||
"as",
|
||||
"me",
|
||||
"the",
|
||||
"a",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"find",
|
||||
"search",
|
||||
"list",
|
||||
"ago",
|
||||
"back",
|
||||
"past",
|
||||
"earlier",
|
||||
}
|
||||
|
||||
def clean_text(self, text):
|
||||
"""Remove stop words from text"""
|
||||
words = text.split()
|
||||
cleaned = " ".join(word for word in words if word.lower() not in self.stop_words)
|
||||
return cleaned
|
||||
|
||||
def parse(self, text):
|
||||
"""Extract all time expressions from text"""
|
||||
# Clean text first
|
||||
cleaned_text = self.clean_text(text)
|
||||
|
||||
matches = []
|
||||
for match in self.regex.finditer(cleaned_text):
|
||||
fuzzy = match.group(1) # "around", "about", etc.
|
||||
number = int(match.group(2))
|
||||
unit = match.group(3).lower()
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"full_match": match.group(0),
|
||||
"fuzzy": bool(fuzzy),
|
||||
"number": number,
|
||||
"unit": unit,
|
||||
"range": self.calculate_range(number, unit, bool(fuzzy)),
|
||||
}
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
def calculate_range(self, number, unit, is_fuzzy):
|
||||
"""Convert to actual datetime range and return ISO format strings"""
|
||||
units = {
|
||||
"hour": timedelta(hours=number),
|
||||
"day": timedelta(days=number),
|
||||
"week": timedelta(weeks=number),
|
||||
"month": timedelta(days=number * 30),
|
||||
"year": timedelta(days=number * 365),
|
||||
}
|
||||
|
||||
delta = units[unit]
|
||||
now = datetime.now()
|
||||
target = now - delta
|
||||
|
||||
if is_fuzzy:
|
||||
buffer = delta * 0.2 # 20% buffer for fuzzy
|
||||
start = (target - buffer).isoformat()
|
||||
end = (target + buffer).isoformat()
|
||||
else:
|
||||
start = target.isoformat()
|
||||
end = now.isoformat()
|
||||
|
||||
return (start, end)
|
||||
|
||||
|
||||
def search_files(query, top_k=15):
|
||||
"""Search the index and return results"""
|
||||
# Parse time expressions
|
||||
parser = TimeParser()
|
||||
time_matches = parser.parse(query)
|
||||
|
||||
# Remove time expressions from query for semantic search
|
||||
clean_query = query
|
||||
if time_matches:
|
||||
for match in time_matches:
|
||||
clean_query = clean_query.replace(match["full_match"], "").strip()
|
||||
|
||||
# Check if clean_query is less than 4 characters
|
||||
if len(clean_query) < 4:
|
||||
print("Error: add more input for accurate results.")
|
||||
return
|
||||
|
||||
# Single query to vector DB
|
||||
searcher = LeannSearcher(INDEX_PATH)
|
||||
results = searcher.search(
|
||||
clean_query if clean_query else query, top_k=top_k, recompute_embeddings=False
|
||||
)
|
||||
|
||||
# Filter by time if time expression found
|
||||
if time_matches:
|
||||
time_range = time_matches[0]["range"] # Use first time expression
|
||||
start_time, end_time = time_range
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
# Access metadata attribute directly (not .get())
|
||||
metadata = result.metadata if hasattr(result, "metadata") else {}
|
||||
|
||||
if metadata:
|
||||
# Check modification date first, fall back to creation date
|
||||
date_str = metadata.get("modification_date") or metadata.get("creation_date")
|
||||
|
||||
if date_str:
|
||||
# Convert strings to datetime objects for proper comparison
|
||||
try:
|
||||
file_date = datetime.fromisoformat(date_str)
|
||||
start_dt = datetime.fromisoformat(start_time)
|
||||
end_dt = datetime.fromisoformat(end_time)
|
||||
|
||||
# Compare dates properly
|
||||
if start_dt <= file_date <= end_dt:
|
||||
filtered_results.append(result)
|
||||
except (ValueError, TypeError):
|
||||
# Handle invalid date formats
|
||||
print(f"Warning: Invalid date format in metadata: {date_str}")
|
||||
continue
|
||||
|
||||
results = filtered_results
|
||||
|
||||
# Print results
|
||||
print(f"\nSearch results for: '{query}'")
|
||||
if time_matches:
|
||||
print(
|
||||
f"Time filter: {time_matches[0]['number']} {time_matches[0]['unit']}(s) {'(fuzzy)' if time_matches[0]['fuzzy'] else ''}"
|
||||
)
|
||||
print(
|
||||
f"Date range: {time_matches[0]['range'][0][:10]} to {time_matches[0]['range'][1][:10]}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
print(f"\n[{i}] Score: {result.score:.4f}")
|
||||
print(f"Content: {result.text}")
|
||||
|
||||
# Show metadata if present
|
||||
metadata = result.metadata if hasattr(result, "metadata") else None
|
||||
if metadata:
|
||||
if "creation_date" in metadata:
|
||||
print(f"Created: {metadata['creation_date']}")
|
||||
if "modification_date" in metadata:
|
||||
print(f"Modified: {metadata['modification_date']}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print('Usage: python search_index.py "<search query>" [top_k]')
|
||||
sys.exit(1)
|
||||
|
||||
query = sys.argv[1]
|
||||
top_k = int(sys.argv[2]) if len(sys.argv) > 2 else 15
|
||||
|
||||
search_files(query, top_k)
|
||||
82
apps/semantic_file_search/leann_index_builder.py
Normal file
82
apps/semantic_file_search/leann_index_builder.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
def process_json_items(json_file_path):
|
||||
"""Load and process JSON file with metadata items"""
|
||||
|
||||
with open(json_file_path, encoding="utf-8") as f:
|
||||
items = json.load(f)
|
||||
|
||||
# Guard against empty JSON
|
||||
if not items:
|
||||
print("⚠️ No items found in the JSON file. Exiting gracefully.")
|
||||
return
|
||||
|
||||
INDEX_PATH = str(Path("./").resolve() / "demo.leann")
|
||||
builder = LeannBuilder(backend_name="hnsw", is_recompute=False)
|
||||
|
||||
total_items = len(items)
|
||||
items_added = 0
|
||||
print(f"Processing {total_items} items...")
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
# Create embedding text sentence
|
||||
embedding_text = f"{item.get('Name', 'unknown')} located at {item.get('Path', 'unknown')} and size {item.get('Size', 'unknown')} bytes with content type {item.get('ContentType', 'unknown')} and kind {item.get('Kind', 'unknown')}"
|
||||
|
||||
# Prepare metadata with dates
|
||||
metadata = {}
|
||||
if "CreationDate" in item:
|
||||
metadata["creation_date"] = item["CreationDate"]
|
||||
if "ContentChangeDate" in item:
|
||||
metadata["modification_date"] = item["ContentChangeDate"]
|
||||
|
||||
# Add to builder
|
||||
builder.add_text(embedding_text, metadata=metadata)
|
||||
items_added += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n⚠️ Warning: Failed to process item {idx}: {e}")
|
||||
continue
|
||||
|
||||
# Show progress
|
||||
progress = (idx + 1) / total_items * 100
|
||||
sys.stdout.write(f"\rProgress: {idx + 1}/{total_items} ({progress:.1f}%)")
|
||||
sys.stdout.flush()
|
||||
|
||||
print() # New line after progress
|
||||
|
||||
# Guard against no successfully added items
|
||||
if items_added == 0:
|
||||
print("⚠️ No items were successfully added to the index. Exiting gracefully.")
|
||||
return
|
||||
|
||||
print(f"\n✅ Successfully processed {items_added}/{total_items} items")
|
||||
print("Building index...")
|
||||
|
||||
try:
|
||||
builder.build_index(INDEX_PATH)
|
||||
print(f"✓ Index saved to {INDEX_PATH}")
|
||||
except ValueError as e:
|
||||
if "No chunks added" in str(e):
|
||||
print("⚠️ No chunks were added to the builder. Index not created.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python build_index.py <json_file>")
|
||||
sys.exit(1)
|
||||
|
||||
json_file = sys.argv[1]
|
||||
if not Path(json_file).exists():
|
||||
print(f"Error: File {json_file} not found")
|
||||
sys.exit(1)
|
||||
|
||||
process_json_items(json_file)
|
||||
265
apps/semantic_file_search/spotlight_index_dump.py
Normal file
265
apps/semantic_file_search/spotlight_index_dump.py
Normal file
@@ -0,0 +1,265 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Spotlight Metadata Dumper for Vector DB
|
||||
Extracts only essential metadata for semantic search embeddings
|
||||
Output is optimized for vector database storage with minimal fields
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Check platform before importing macOS-specific modules
|
||||
if sys.platform != "darwin":
|
||||
print("This script requires macOS (uses Spotlight)")
|
||||
sys.exit(1)
|
||||
|
||||
from Foundation import NSDate, NSMetadataQuery, NSPredicate, NSRunLoop
|
||||
|
||||
# EDIT THIS LIST: Add or remove folders to search
|
||||
# Can be either:
|
||||
# - Folder names relative to home directory (e.g., "Desktop", "Downloads")
|
||||
# - Absolute paths (e.g., "/Applications", "/System/Library")
|
||||
SEARCH_FOLDERS = [
|
||||
"Desktop",
|
||||
"Downloads",
|
||||
"Documents",
|
||||
"Music",
|
||||
"Pictures",
|
||||
"Movies",
|
||||
# "Library", # Uncomment to include
|
||||
# "/Applications", # Absolute path example
|
||||
# "Code/Projects", # Subfolder example
|
||||
# Add any other folders here
|
||||
]
|
||||
|
||||
|
||||
def convert_to_serializable(obj):
|
||||
"""Convert NS objects to Python serializable types"""
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
# Handle NSDate
|
||||
if hasattr(obj, "timeIntervalSince1970"):
|
||||
return datetime.fromtimestamp(obj.timeIntervalSince1970()).isoformat()
|
||||
|
||||
# Handle NSArray
|
||||
if hasattr(obj, "count") and hasattr(obj, "objectAtIndex_"):
|
||||
return [convert_to_serializable(obj.objectAtIndex_(i)) for i in range(obj.count())]
|
||||
|
||||
# Convert to string
|
||||
try:
|
||||
return str(obj)
|
||||
except Exception:
|
||||
return repr(obj)
|
||||
|
||||
|
||||
def dump_spotlight_data(max_items=10, output_file="spotlight_dump.json"):
|
||||
"""
|
||||
Dump Spotlight data using public.item predicate
|
||||
"""
|
||||
# Build full paths from SEARCH_FOLDERS
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
search_paths = []
|
||||
|
||||
print("Search locations:")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Check if it's an absolute path or relative
|
||||
if folder.startswith("/"):
|
||||
full_path = folder
|
||||
else:
|
||||
full_path = os.path.join(home_dir, folder)
|
||||
|
||||
if os.path.exists(full_path):
|
||||
search_paths.append(full_path)
|
||||
print(f" ✓ {full_path}")
|
||||
else:
|
||||
print(f" ✗ {full_path} (not found)")
|
||||
|
||||
if not search_paths:
|
||||
print("No valid search paths found!")
|
||||
return []
|
||||
|
||||
print(f"\nDumping {max_items} items from Spotlight (public.item)...")
|
||||
|
||||
# Create query with public.item predicate
|
||||
query = NSMetadataQuery.alloc().init()
|
||||
predicate = NSPredicate.predicateWithFormat_("kMDItemContentTypeTree CONTAINS 'public.item'")
|
||||
query.setPredicate_(predicate)
|
||||
|
||||
# Set search scopes to our specific folders
|
||||
query.setSearchScopes_(search_paths)
|
||||
|
||||
print("Starting query...")
|
||||
query.startQuery()
|
||||
|
||||
# Wait for gathering to complete
|
||||
run_loop = NSRunLoop.currentRunLoop()
|
||||
print("Gathering results...")
|
||||
|
||||
# Let it gather for a few seconds
|
||||
for i in range(50): # 5 seconds max
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
# Check gathering status periodically
|
||||
if i % 10 == 0:
|
||||
current_count = query.resultCount()
|
||||
if current_count > 0:
|
||||
print(f" Found {current_count} items so far...")
|
||||
|
||||
# Continue while still gathering (up to 2 more seconds)
|
||||
timeout = NSDate.dateWithTimeIntervalSinceNow_(2.0)
|
||||
while query.isGathering() and timeout.timeIntervalSinceNow() > 0:
|
||||
run_loop.runMode_beforeDate_(
|
||||
"NSDefaultRunLoopMode", NSDate.dateWithTimeIntervalSinceNow_(0.1)
|
||||
)
|
||||
|
||||
query.stopQuery()
|
||||
|
||||
total_results = query.resultCount()
|
||||
print(f"Found {total_results} total items")
|
||||
|
||||
if total_results == 0:
|
||||
print("No results found")
|
||||
return []
|
||||
|
||||
# Process items
|
||||
items_to_process = min(total_results, max_items)
|
||||
results = []
|
||||
|
||||
# ONLY relevant attributes for vector embeddings
|
||||
# These provide essential context for semantic search without bloat
|
||||
attributes = [
|
||||
"kMDItemPath", # Full path for file retrieval
|
||||
"kMDItemFSName", # Filename for display & embedding
|
||||
"kMDItemFSSize", # Size for filtering/ranking
|
||||
"kMDItemContentType", # File type for categorization
|
||||
"kMDItemKind", # Human-readable type for embedding
|
||||
"kMDItemFSCreationDate", # Temporal context
|
||||
"kMDItemFSContentChangeDate", # Recency for ranking
|
||||
]
|
||||
|
||||
print(f"Processing {items_to_process} items...")
|
||||
|
||||
for i in range(items_to_process):
|
||||
try:
|
||||
item = query.resultAtIndex_(i)
|
||||
metadata = {}
|
||||
|
||||
# Extract ONLY the relevant attributes
|
||||
for attr in attributes:
|
||||
try:
|
||||
value = item.valueForAttribute_(attr)
|
||||
if value is not None:
|
||||
# Keep the attribute name clean (remove kMDItem prefix for cleaner JSON)
|
||||
clean_key = attr.replace("kMDItem", "").replace("FS", "")
|
||||
metadata[clean_key] = convert_to_serializable(value)
|
||||
except (AttributeError, ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Only add if we have at least a path
|
||||
if metadata.get("Path"):
|
||||
results.append(metadata)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing item {i}: {e}")
|
||||
continue
|
||||
|
||||
# Save to JSON
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(results, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\n✓ Saved {len(results)} items to {output_file}")
|
||||
|
||||
# Show summary
|
||||
print("\nSample items:")
|
||||
import os
|
||||
|
||||
home_dir = os.path.expanduser("~")
|
||||
|
||||
for i, item in enumerate(results[:3]):
|
||||
print(f"\n[Item {i + 1}]")
|
||||
print(f" Path: {item.get('Path', 'N/A')}")
|
||||
print(f" Name: {item.get('Name', 'N/A')}")
|
||||
print(f" Type: {item.get('ContentType', 'N/A')}")
|
||||
print(f" Kind: {item.get('Kind', 'N/A')}")
|
||||
|
||||
# Handle size properly
|
||||
size = item.get("Size")
|
||||
if size:
|
||||
try:
|
||||
size_int = int(size)
|
||||
if size_int > 1024 * 1024:
|
||||
print(f" Size: {size_int / (1024 * 1024):.2f} MB")
|
||||
elif size_int > 1024:
|
||||
print(f" Size: {size_int / 1024:.2f} KB")
|
||||
else:
|
||||
print(f" Size: {size_int} bytes")
|
||||
except (ValueError, TypeError):
|
||||
print(f" Size: {size}")
|
||||
|
||||
# Show dates
|
||||
if "CreationDate" in item:
|
||||
print(f" Created: {item['CreationDate']}")
|
||||
if "ContentChangeDate" in item:
|
||||
print(f" Modified: {item['ContentChangeDate']}")
|
||||
|
||||
# Count by type
|
||||
type_counts = {}
|
||||
for item in results:
|
||||
content_type = item.get("ContentType", "unknown")
|
||||
type_counts[content_type] = type_counts.get(content_type, 0) + 1
|
||||
|
||||
print(f"\nTotal items saved: {len(results)}")
|
||||
|
||||
if type_counts:
|
||||
print("\nTop content types:")
|
||||
for ct, count in sorted(type_counts.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
print(f" {ct}: {count} items")
|
||||
|
||||
# Count by folder
|
||||
folder_counts = {}
|
||||
for item in results:
|
||||
path = item.get("Path", "")
|
||||
for folder in SEARCH_FOLDERS:
|
||||
# Build the full folder path
|
||||
if folder.startswith("/"):
|
||||
folder_path = folder
|
||||
else:
|
||||
folder_path = os.path.join(home_dir, folder)
|
||||
|
||||
if path.startswith(folder_path):
|
||||
folder_counts[folder] = folder_counts.get(folder, 0) + 1
|
||||
break
|
||||
|
||||
if folder_counts:
|
||||
print("\nItems by location:")
|
||||
for folder, count in sorted(folder_counts.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {folder}: {count} items")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments
|
||||
if len(sys.argv) > 1:
|
||||
try:
|
||||
max_items = int(sys.argv[1])
|
||||
except ValueError:
|
||||
print("Usage: python spot.py [number_of_items]")
|
||||
print("Default: 10 items")
|
||||
sys.exit(1)
|
||||
else:
|
||||
max_items = 10
|
||||
|
||||
output_file = sys.argv[2] if len(sys.argv) > 2 else "spotlight_dump.json"
|
||||
|
||||
# Run dump
|
||||
dump_spotlight_data(max_items=max_items, output_file=output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
apps/slack_data/__init__.py
Normal file
1
apps/slack_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Slack MCP data integration for LEANN
|
||||
516
apps/slack_data/slack_mcp_reader.py
Normal file
516
apps/slack_data/slack_mcp_reader.py
Normal file
@@ -0,0 +1,516 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Slack MCP servers and fetch message data
|
||||
for indexing in LEANN. It supports various Slack MCP server implementations and provides
|
||||
flexible message processing options.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackMCPReader:
|
||||
"""
|
||||
Reader for Slack data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Slack MCP servers to fetch message data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
workspace_name: Optional[str] = None,
|
||||
concatenate_conversations: bool = True,
|
||||
max_messages_per_conversation: int = 100,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize the Slack MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'slack-mcp-server')
|
||||
workspace_name: Optional workspace name to filter messages
|
||||
concatenate_conversations: Whether to group messages by channel/thread
|
||||
max_messages_per_conversation: Maximum messages to include per conversation
|
||||
max_retries: Maximum number of retries for failed operations
|
||||
retry_delay: Initial delay between retries in seconds
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.workspace_name = workspace_name
|
||||
self.concatenate_conversations = concatenate_conversations
|
||||
self.max_messages_per_conversation = max_messages_per_conversation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-slack-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
def _is_cache_sync_error(self, error: dict) -> bool:
|
||||
"""Check if the error is related to users cache not being ready."""
|
||||
if isinstance(error, dict):
|
||||
message = error.get("message", "").lower()
|
||||
return (
|
||||
"users cache is not ready" in message or "sync process is still running" in message
|
||||
)
|
||||
return False
|
||||
|
||||
async def _retry_with_backoff(self, func, *args, **kwargs):
|
||||
"""Retry a function with exponential backoff, especially for cache sync issues."""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# Check if this is a cache sync error
|
||||
error_dict = {}
|
||||
if hasattr(e, "args") and e.args and isinstance(e.args[0], dict):
|
||||
error_dict = e.args[0]
|
||||
elif "Failed to fetch messages" in str(e):
|
||||
# Try to extract error from the exception message
|
||||
import re
|
||||
|
||||
match = re.search(r"'error':\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
else:
|
||||
# Try alternative format
|
||||
match = re.search(r"Failed to fetch messages:\s*(\{[^}]+\})", str(e))
|
||||
if match:
|
||||
try:
|
||||
error_dict = ast.literal_eval(match.group(1))
|
||||
except (ValueError, SyntaxError):
|
||||
pass
|
||||
|
||||
if self._is_cache_sync_error(error_dict):
|
||||
if attempt < self.max_retries:
|
||||
delay = self.retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.info(
|
||||
f"Cache sync not ready, waiting {delay:.1f}s before retry {attempt + 1}/{self.max_retries}"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cache sync still not ready after {self.max_retries} retries, giving up"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a cache sync error, don't retry
|
||||
break
|
||||
|
||||
# If we get here, all retries failed or it's not a retryable error
|
||||
if last_exception is not None:
|
||||
raise last_exception
|
||||
raise RuntimeError("Unexpected error: no exception captured during retry loop")
|
||||
|
||||
async def fetch_slack_messages(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Slack messages using MCP tools with retry logic for cache sync issues.
|
||||
|
||||
Args:
|
||||
channel: Optional channel name to filter messages
|
||||
limit: Maximum number of messages to fetch
|
||||
|
||||
Returns:
|
||||
List of message dictionaries
|
||||
"""
|
||||
return await self._retry_with_backoff(self._fetch_slack_messages_impl, channel, limit)
|
||||
|
||||
async def _fetch_slack_messages_impl(
|
||||
self, channel: Optional[str] = None, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Internal implementation of fetch_slack_messages without retry logic.
|
||||
"""
|
||||
# This is a generic implementation - specific MCP servers may have different tool names
|
||||
# Common tool names might be: 'get_messages', 'list_messages', 'fetch_channel_history'
|
||||
|
||||
tools = await self.list_available_tools()
|
||||
logger.info(f"Available tools: {[tool.get('name') for tool in tools]}")
|
||||
message_tool = None
|
||||
|
||||
# Look for a tool that can fetch messages - prioritize conversations_history
|
||||
message_tool = None
|
||||
|
||||
# First, try to find conversations_history specifically
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if "conversations_history" in tool_name:
|
||||
message_tool = tool
|
||||
logger.info(f"Found conversations_history tool: {tool}")
|
||||
break
|
||||
|
||||
# If not found, look for other message-fetching tools
|
||||
if not message_tool:
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(
|
||||
keyword in tool_name
|
||||
for keyword in ["conversations_search", "message", "history"]
|
||||
):
|
||||
message_tool = tool
|
||||
break
|
||||
|
||||
if not message_tool:
|
||||
raise RuntimeError("No message fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {"limit": "180d"} # Use 180 days to get older messages
|
||||
if channel:
|
||||
# For conversations_history, use channel_id parameter
|
||||
if message_tool["name"] == "conversations_history":
|
||||
tool_params["channel_id"] = channel
|
||||
else:
|
||||
# Try common parameter names for channel specification
|
||||
for param_name in ["channel", "channel_id", "channel_name"]:
|
||||
tool_params[param_name] = channel
|
||||
break
|
||||
|
||||
logger.info(f"Tool parameters: {tool_params}")
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": message_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch messages: {response['error']}")
|
||||
|
||||
# Extract messages from response - format may vary by MCP server
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
# Some MCP servers return content as a list
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
messages = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, try to parse as CSV format (Slack MCP server format)
|
||||
text_content = content.get("text", "")
|
||||
messages = self._parse_csv_messages(
|
||||
text_content if text_content else "", channel or "unknown"
|
||||
)
|
||||
else:
|
||||
messages = result["content"]
|
||||
else:
|
||||
# Direct message format
|
||||
messages = result.get("messages", [result])
|
||||
|
||||
return messages if isinstance(messages, list) else [messages]
|
||||
|
||||
def _parse_csv_messages(self, csv_text: str, channel: str) -> list[dict[str, Any]]:
|
||||
"""Parse CSV format messages from Slack MCP server."""
|
||||
import csv
|
||||
import io
|
||||
|
||||
messages = []
|
||||
try:
|
||||
# Split by lines and process each line as a CSV row
|
||||
lines = csv_text.strip().split("\n")
|
||||
if not lines:
|
||||
return messages
|
||||
|
||||
# Skip header line if it exists
|
||||
start_idx = 0
|
||||
if lines[0].startswith("MsgID,UserID,UserName"):
|
||||
start_idx = 1
|
||||
|
||||
for line in lines[start_idx:]:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Parse CSV line
|
||||
reader = csv.reader(io.StringIO(line))
|
||||
try:
|
||||
row = next(reader)
|
||||
if len(row) >= 7: # Ensure we have enough columns
|
||||
message = {
|
||||
"ts": row[0],
|
||||
"user": row[1],
|
||||
"username": row[2],
|
||||
"real_name": row[3],
|
||||
"channel": row[4],
|
||||
"thread_ts": row[5],
|
||||
"text": row[6],
|
||||
"time": row[7] if len(row) > 7 else "",
|
||||
"reactions": row[8] if len(row) > 8 else "",
|
||||
"cursor": row[9] if len(row) > 9 else "",
|
||||
}
|
||||
messages.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV line: {line[:100]}... Error: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse CSV messages: {e}")
|
||||
# Fallback: treat entire text as one message
|
||||
messages = [{"text": csv_text, "channel": channel or "unknown"}]
|
||||
|
||||
return messages
|
||||
|
||||
def _format_message(self, message: dict[str, Any]) -> str:
|
||||
"""Format a single message for indexing."""
|
||||
text = message.get("text", "")
|
||||
user = message.get("user", message.get("username", "Unknown"))
|
||||
channel = message.get("channel", message.get("channel_name", "Unknown"))
|
||||
timestamp = message.get("ts", message.get("timestamp", ""))
|
||||
|
||||
# Format timestamp if available
|
||||
formatted_time = ""
|
||||
if timestamp:
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if isinstance(timestamp, str) and "." in timestamp:
|
||||
dt = datetime.datetime.fromtimestamp(float(timestamp))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
elif isinstance(timestamp, (int, float)):
|
||||
dt = datetime.datetime.fromtimestamp(timestamp)
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
except (ValueError, TypeError):
|
||||
formatted_time = str(timestamp)
|
||||
|
||||
# Build formatted message
|
||||
parts = []
|
||||
if channel:
|
||||
parts.append(f"Channel: #{channel}")
|
||||
if user:
|
||||
parts.append(f"User: {user}")
|
||||
if formatted_time:
|
||||
parts.append(f"Time: {formatted_time}")
|
||||
if text:
|
||||
parts.append(f"Message: {text}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _create_concatenated_content(self, messages: list[dict[str, Any]], channel: str) -> str:
|
||||
"""Create concatenated content from multiple messages in a channel."""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Sort messages by timestamp if available
|
||||
try:
|
||||
messages.sort(key=lambda x: float(x.get("ts", x.get("timestamp", 0))))
|
||||
except (ValueError, TypeError):
|
||||
pass # Keep original order if timestamps aren't numeric
|
||||
|
||||
# Limit messages per conversation
|
||||
if len(messages) > self.max_messages_per_conversation:
|
||||
messages = messages[-self.max_messages_per_conversation :]
|
||||
|
||||
# Create header
|
||||
content_parts = [
|
||||
f"Slack Channel: #{channel}",
|
||||
f"Message Count: {len(messages)}",
|
||||
f"Workspace: {self.workspace_name or 'Unknown'}",
|
||||
"=" * 50,
|
||||
"",
|
||||
]
|
||||
|
||||
# Add messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
content_parts.append(formatted_msg)
|
||||
content_parts.append("-" * 30)
|
||||
content_parts.append("")
|
||||
|
||||
return "\n".join(content_parts)
|
||||
|
||||
async def get_all_channels(self) -> list[str]:
|
||||
"""Get list of all available channels."""
|
||||
try:
|
||||
channels_list_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "channels_list", "arguments": {}},
|
||||
}
|
||||
channels_response = await self.send_mcp_request(channels_list_request)
|
||||
if "result" in channels_response:
|
||||
result = channels_response["result"]
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
# Parse the channels from the response
|
||||
channels = []
|
||||
lines = content["text"].split("\n")
|
||||
for line in lines:
|
||||
if line.strip() and ("#" in line or "C" in line[:10]):
|
||||
# Extract channel ID or name
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if part.startswith("C") and len(part) > 5:
|
||||
channels.append(part)
|
||||
elif part.startswith("#"):
|
||||
channels.append(part[1:]) # Remove #
|
||||
logger.info(f"Found {len(channels)} channels: {channels}")
|
||||
return channels
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get channels list: {e}")
|
||||
return []
|
||||
|
||||
async def read_slack_data(self, channels: Optional[list[str]] = None) -> list[str]:
|
||||
"""
|
||||
Read Slack data and return formatted text chunks.
|
||||
|
||||
Args:
|
||||
channels: Optional list of channel names to fetch. If None, fetches from all available channels.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
all_texts = []
|
||||
|
||||
if channels:
|
||||
# Fetch specific channels
|
||||
for channel in channels:
|
||||
try:
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
else:
|
||||
# Fetch from all available channels
|
||||
logger.info("Fetching from all available channels...")
|
||||
all_channels = await self.get_all_channels()
|
||||
|
||||
if not all_channels:
|
||||
# Fallback to common channel names if we can't get the list
|
||||
all_channels = ["general", "random", "announcements", "C0GN5BX0F"]
|
||||
logger.info(f"Using fallback channels: {all_channels}")
|
||||
|
||||
for channel in all_channels:
|
||||
try:
|
||||
logger.info(f"Searching channel: {channel}")
|
||||
messages = await self.fetch_slack_messages(channel=channel, limit=1000)
|
||||
if messages:
|
||||
if self.concatenate_conversations:
|
||||
text_content = self._create_concatenated_content(messages, channel)
|
||||
if text_content.strip():
|
||||
all_texts.append(text_content)
|
||||
else:
|
||||
# Process individual messages
|
||||
for message in messages:
|
||||
formatted_msg = self._format_message(message)
|
||||
if formatted_msg.strip():
|
||||
all_texts.append(formatted_msg)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch messages from channel {channel}: {e}")
|
||||
continue
|
||||
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
229
apps/slack_rag.py
Normal file
229
apps/slack_rag.py
Normal file
@@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slack RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Slack messages
|
||||
by connecting to Slack MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.slack_rag --mcp-server "slack-mcp-server" --query "What did the team discuss about the project?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.slack_data.slack_mcp_reader import SlackMCPReader
|
||||
|
||||
|
||||
class SlackMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Slack messages via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Slack data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Slack MCP RAG",
|
||||
description="RAG application for Slack messages via MCP servers",
|
||||
default_index_name="slack_messages",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Slack MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Slack MCP server (e.g., 'slack-mcp-server' or 'npx slack-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workspace-name",
|
||||
type=str,
|
||||
help="Slack workspace name for better organization and filtering",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--channels",
|
||||
nargs="+",
|
||||
help="Specific Slack channels to index (e.g., general random). If not specified, fetches from all available channels",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--concatenate-conversations",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Group messages by channel/thread for better context (default: True)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-concatenate-conversations",
|
||||
action="store_true",
|
||||
help="Process individual messages instead of grouping by channel",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-messages-per-channel",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of messages to include per channel (default: 100)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-retries",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Maximum number of retries for failed operations (default: 5)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--retry-delay",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="Initial delay between retries in seconds (default: 2.0)",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=not args.no_concatenate_conversations,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper authentication/credentials configured")
|
||||
print("4. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Slack messages via MCP server."""
|
||||
print(f"Connecting to Slack MCP server: {args.mcp_server}")
|
||||
|
||||
if args.workspace_name:
|
||||
print(f"Workspace: {args.workspace_name}")
|
||||
|
||||
# Filter out empty strings from channels
|
||||
channels = [ch for ch in args.channels if ch.strip()] if args.channels else None
|
||||
|
||||
if channels:
|
||||
print(f"Channels: {', '.join(channels)}")
|
||||
else:
|
||||
print("Fetching from all available channels")
|
||||
|
||||
concatenate = not args.no_concatenate_conversations
|
||||
print(
|
||||
f"Processing mode: {'Concatenated conversations' if concatenate else 'Individual messages'}"
|
||||
)
|
||||
|
||||
try:
|
||||
reader = SlackMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
workspace_name=args.workspace_name,
|
||||
concatenate_conversations=concatenate,
|
||||
max_messages_per_conversation=args.max_messages_per_channel,
|
||||
max_retries=args.max_retries,
|
||||
retry_delay=args.retry_delay,
|
||||
)
|
||||
|
||||
texts = await reader.read_slack_data(channels=channels)
|
||||
|
||||
if not texts:
|
||||
print("No messages found! This could mean:")
|
||||
print("- The MCP server couldn't fetch messages")
|
||||
print("- The specified channels don't exist or are empty")
|
||||
print("- Authentication issues with the Slack workspace")
|
||||
return []
|
||||
|
||||
print(f"Successfully loaded {len(texts)} text chunks from Slack")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:200] + "..." if len(texts[0]) > 200 else texts[0]
|
||||
print("\nSample content:")
|
||||
print("-" * 40)
|
||||
print(sample_text)
|
||||
print("-" * 40)
|
||||
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "slack"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error loading Slack data: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Incorrect channel names")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Slack MCP RAG application."""
|
||||
app = SlackMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
1
apps/twitter_data/__init__.py
Normal file
1
apps/twitter_data/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Twitter MCP data integration for LEANN
|
||||
295
apps/twitter_data/twitter_mcp_reader.py
Normal file
295
apps/twitter_data/twitter_mcp_reader.py
Normal file
@@ -0,0 +1,295 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter MCP Reader for LEANN
|
||||
|
||||
This module provides functionality to connect to Twitter MCP servers and fetch bookmark data
|
||||
for indexing in LEANN. It supports various Twitter MCP server implementations and provides
|
||||
flexible bookmark processing options.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TwitterMCPReader:
|
||||
"""
|
||||
Reader for Twitter bookmark data via MCP (Model Context Protocol) servers.
|
||||
|
||||
This class connects to Twitter MCP servers to fetch bookmark data and convert it
|
||||
into a format suitable for LEANN indexing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mcp_server_command: str,
|
||||
username: Optional[str] = None,
|
||||
include_tweet_content: bool = True,
|
||||
include_metadata: bool = True,
|
||||
max_bookmarks: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the Twitter MCP Reader.
|
||||
|
||||
Args:
|
||||
mcp_server_command: Command to start the MCP server (e.g., 'twitter-mcp-server')
|
||||
username: Optional Twitter username to filter bookmarks
|
||||
include_tweet_content: Whether to include full tweet content
|
||||
include_metadata: Whether to include tweet metadata (likes, retweets, etc.)
|
||||
max_bookmarks: Maximum number of bookmarks to fetch
|
||||
"""
|
||||
self.mcp_server_command = mcp_server_command
|
||||
self.username = username
|
||||
self.include_tweet_content = include_tweet_content
|
||||
self.include_metadata = include_metadata
|
||||
self.max_bookmarks = max_bookmarks
|
||||
self.mcp_process = None
|
||||
|
||||
async def start_mcp_server(self):
|
||||
"""Start the MCP server process."""
|
||||
try:
|
||||
self.mcp_process = await asyncio.create_subprocess_exec(
|
||||
*self.mcp_server_command.split(),
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info(f"Started MCP server: {self.mcp_server_command}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}")
|
||||
raise
|
||||
|
||||
async def stop_mcp_server(self):
|
||||
"""Stop the MCP server process."""
|
||||
if self.mcp_process:
|
||||
self.mcp_process.terminate()
|
||||
await self.mcp_process.wait()
|
||||
logger.info("Stopped MCP server")
|
||||
|
||||
async def send_mcp_request(self, request: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send a request to the MCP server and get response."""
|
||||
if not self.mcp_process:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.mcp_process.stdin.write(request_json.encode())
|
||||
await self.mcp_process.stdin.drain()
|
||||
|
||||
response_line = await self.mcp_process.stdout.readline()
|
||||
if not response_line:
|
||||
raise RuntimeError("No response from MCP server")
|
||||
|
||||
return json.loads(response_line.decode().strip())
|
||||
|
||||
async def initialize_mcp_connection(self):
|
||||
"""Initialize the MCP connection."""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "leann-twitter-reader", "version": "1.0.0"},
|
||||
},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(init_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"MCP initialization failed: {response['error']}")
|
||||
|
||||
logger.info("MCP connection initialized successfully")
|
||||
|
||||
async def list_available_tools(self) -> list[dict[str, Any]]:
|
||||
"""List available tools from the MCP server."""
|
||||
list_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {}}
|
||||
|
||||
response = await self.send_mcp_request(list_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to list tools: {response['error']}")
|
||||
|
||||
return response.get("result", {}).get("tools", [])
|
||||
|
||||
async def fetch_twitter_bookmarks(self, limit: Optional[int] = None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch Twitter bookmarks using MCP tools.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of bookmarks to fetch
|
||||
|
||||
Returns:
|
||||
List of bookmark dictionaries
|
||||
"""
|
||||
tools = await self.list_available_tools()
|
||||
bookmark_tool = None
|
||||
|
||||
# Look for a tool that can fetch bookmarks
|
||||
for tool in tools:
|
||||
tool_name = tool.get("name", "").lower()
|
||||
if any(keyword in tool_name for keyword in ["bookmark", "saved", "favorite"]):
|
||||
bookmark_tool = tool
|
||||
break
|
||||
|
||||
if not bookmark_tool:
|
||||
raise RuntimeError("No bookmark fetching tool found in MCP server")
|
||||
|
||||
# Prepare tool call parameters
|
||||
tool_params = {}
|
||||
if limit or self.max_bookmarks:
|
||||
tool_params["limit"] = limit or self.max_bookmarks
|
||||
if self.username:
|
||||
tool_params["username"] = self.username
|
||||
|
||||
fetch_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "tools/call",
|
||||
"params": {"name": bookmark_tool["name"], "arguments": tool_params},
|
||||
}
|
||||
|
||||
response = await self.send_mcp_request(fetch_request)
|
||||
if "error" in response:
|
||||
raise RuntimeError(f"Failed to fetch bookmarks: {response['error']}")
|
||||
|
||||
# Extract bookmarks from response
|
||||
result = response.get("result", {})
|
||||
if "content" in result and isinstance(result["content"], list):
|
||||
content = result["content"][0] if result["content"] else {}
|
||||
if "text" in content:
|
||||
try:
|
||||
bookmarks = json.loads(content["text"])
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, treat as plain text
|
||||
bookmarks = [{"text": content["text"], "source": "twitter"}]
|
||||
else:
|
||||
bookmarks = result["content"]
|
||||
else:
|
||||
bookmarks = result.get("bookmarks", result.get("tweets", [result]))
|
||||
|
||||
return bookmarks if isinstance(bookmarks, list) else [bookmarks]
|
||||
|
||||
def _format_bookmark(self, bookmark: dict[str, Any]) -> str:
|
||||
"""Format a single bookmark for indexing."""
|
||||
# Extract tweet information
|
||||
text = bookmark.get("text", bookmark.get("content", ""))
|
||||
author = bookmark.get(
|
||||
"author", bookmark.get("username", bookmark.get("user", {}).get("username", "Unknown"))
|
||||
)
|
||||
timestamp = bookmark.get("created_at", bookmark.get("timestamp", ""))
|
||||
url = bookmark.get("url", bookmark.get("tweet_url", ""))
|
||||
|
||||
# Extract metadata if available
|
||||
likes = bookmark.get("likes", bookmark.get("favorite_count", 0))
|
||||
retweets = bookmark.get("retweets", bookmark.get("retweet_count", 0))
|
||||
replies = bookmark.get("replies", bookmark.get("reply_count", 0))
|
||||
|
||||
# Build formatted bookmark
|
||||
parts = []
|
||||
|
||||
# Header
|
||||
parts.append("=== Twitter Bookmark ===")
|
||||
|
||||
if author:
|
||||
parts.append(f"Author: @{author}")
|
||||
|
||||
if timestamp:
|
||||
# Format timestamp if it's a standard format
|
||||
try:
|
||||
import datetime
|
||||
|
||||
if "T" in str(timestamp): # ISO format
|
||||
dt = datetime.datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
formatted_time = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
formatted_time = str(timestamp)
|
||||
parts.append(f"Date: {formatted_time}")
|
||||
except (ValueError, TypeError):
|
||||
parts.append(f"Date: {timestamp}")
|
||||
|
||||
if url:
|
||||
parts.append(f"URL: {url}")
|
||||
|
||||
# Tweet content
|
||||
if text and self.include_tweet_content:
|
||||
parts.append("")
|
||||
parts.append("Content:")
|
||||
parts.append(text)
|
||||
|
||||
# Metadata
|
||||
if self.include_metadata and any([likes, retweets, replies]):
|
||||
parts.append("")
|
||||
parts.append("Engagement:")
|
||||
if likes:
|
||||
parts.append(f" Likes: {likes}")
|
||||
if retweets:
|
||||
parts.append(f" Retweets: {retweets}")
|
||||
if replies:
|
||||
parts.append(f" Replies: {replies}")
|
||||
|
||||
# Extract hashtags and mentions if available
|
||||
hashtags = bookmark.get("hashtags", [])
|
||||
mentions = bookmark.get("mentions", [])
|
||||
|
||||
if hashtags or mentions:
|
||||
parts.append("")
|
||||
if hashtags:
|
||||
parts.append(f"Hashtags: {', '.join(hashtags)}")
|
||||
if mentions:
|
||||
parts.append(f"Mentions: {', '.join(mentions)}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
async def read_twitter_bookmarks(self) -> list[str]:
|
||||
"""
|
||||
Read Twitter bookmark data and return formatted text chunks.
|
||||
|
||||
Returns:
|
||||
List of formatted text chunks ready for LEANN indexing
|
||||
"""
|
||||
try:
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
|
||||
print(f"Fetching up to {self.max_bookmarks} bookmarks...")
|
||||
if self.username:
|
||||
print(f"Filtering for user: @{self.username}")
|
||||
|
||||
bookmarks = await self.fetch_twitter_bookmarks()
|
||||
|
||||
if not bookmarks:
|
||||
print("No bookmarks found")
|
||||
return []
|
||||
|
||||
print(f"Processing {len(bookmarks)} bookmarks...")
|
||||
|
||||
all_texts = []
|
||||
processed_count = 0
|
||||
|
||||
for bookmark in bookmarks:
|
||||
try:
|
||||
formatted_bookmark = self._format_bookmark(bookmark)
|
||||
if formatted_bookmark.strip():
|
||||
all_texts.append(formatted_bookmark)
|
||||
processed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format bookmark: {e}")
|
||||
continue
|
||||
|
||||
print(f"Successfully processed {processed_count} bookmarks")
|
||||
return all_texts
|
||||
|
||||
finally:
|
||||
await self.stop_mcp_server()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry."""
|
||||
await self.start_mcp_server()
|
||||
await self.initialize_mcp_connection()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit."""
|
||||
await self.stop_mcp_server()
|
||||
197
apps/twitter_rag.py
Normal file
197
apps/twitter_rag.py
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Twitter RAG Application with MCP Support
|
||||
|
||||
This application enables RAG (Retrieval-Augmented Generation) on Twitter bookmarks
|
||||
by connecting to Twitter MCP servers to fetch live data and index it in LEANN.
|
||||
|
||||
Usage:
|
||||
python -m apps.twitter_rag --mcp-server "twitter-mcp-server" --query "What articles did I bookmark about AI?"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from apps.base_rag_example import BaseRAGExample
|
||||
from apps.twitter_data.twitter_mcp_reader import TwitterMCPReader
|
||||
|
||||
|
||||
class TwitterMCPRAG(BaseRAGExample):
|
||||
"""
|
||||
RAG application for Twitter bookmarks via MCP servers.
|
||||
|
||||
This class provides a complete RAG pipeline for Twitter bookmark data, including
|
||||
MCP server connection, data fetching, indexing, and interactive chat.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Twitter MCP RAG",
|
||||
description="RAG application for Twitter bookmarks via MCP servers",
|
||||
default_index_name="twitter_bookmarks",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser: argparse.ArgumentParser):
|
||||
"""Add Twitter MCP-specific arguments."""
|
||||
parser.add_argument(
|
||||
"--mcp-server",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Command to start the Twitter MCP server (e.g., 'twitter-mcp-server' or 'npx twitter-mcp-server')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--username", type=str, help="Twitter username to filter bookmarks (without @)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-bookmarks",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Maximum number of bookmarks to fetch (default: 1000)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-tweet-content",
|
||||
action="store_true",
|
||||
help="Exclude tweet content, only include metadata",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-metadata",
|
||||
action="store_true",
|
||||
help="Exclude engagement metadata (likes, retweets, etc.)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--test-connection",
|
||||
action="store_true",
|
||||
help="Test MCP server connection and list available tools without indexing",
|
||||
)
|
||||
|
||||
async def test_mcp_connection(self, args) -> bool:
|
||||
"""Test the MCP server connection and display available tools."""
|
||||
print(f"Testing connection to MCP server: {args.mcp_server}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
async with reader:
|
||||
tools = await reader.list_available_tools()
|
||||
|
||||
print("\n✅ Successfully connected to MCP server!")
|
||||
print(f"Available tools ({len(tools)}):")
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
name = tool.get("name", "Unknown")
|
||||
description = tool.get("description", "No description available")
|
||||
print(f"\n{i}. {name}")
|
||||
print(
|
||||
f" Description: {description[:100]}{'...' if len(description) > 100 else ''}"
|
||||
)
|
||||
|
||||
# Show input schema if available
|
||||
schema = tool.get("inputSchema", {})
|
||||
if schema.get("properties"):
|
||||
props = list(schema["properties"].keys())[:3] # Show first 3 properties
|
||||
print(
|
||||
f" Parameters: {', '.join(props)}{'...' if len(schema['properties']) > 3 else ''}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ Failed to connect to MCP server: {e}")
|
||||
print("\nTroubleshooting tips:")
|
||||
print("1. Make sure the Twitter MCP server is installed and accessible")
|
||||
print("2. Check if the server command is correct")
|
||||
print("3. Ensure you have proper Twitter API credentials configured")
|
||||
print("4. Verify your Twitter account has bookmarks to fetch")
|
||||
print("5. Try running the MCP server command directly to test it")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load Twitter bookmarks via MCP server."""
|
||||
print(f"Connecting to Twitter MCP server: {args.mcp_server}")
|
||||
|
||||
if args.username:
|
||||
print(f"Username filter: @{args.username}")
|
||||
|
||||
print(f"Max bookmarks: {args.max_bookmarks}")
|
||||
print(f"Include tweet content: {not args.no_tweet_content}")
|
||||
print(f"Include metadata: {not args.no_metadata}")
|
||||
|
||||
try:
|
||||
reader = TwitterMCPReader(
|
||||
mcp_server_command=args.mcp_server,
|
||||
username=args.username,
|
||||
include_tweet_content=not args.no_tweet_content,
|
||||
include_metadata=not args.no_metadata,
|
||||
max_bookmarks=args.max_bookmarks,
|
||||
)
|
||||
|
||||
texts = await reader.read_twitter_bookmarks()
|
||||
|
||||
if not texts:
|
||||
print("❌ No bookmarks found! This could mean:")
|
||||
print("- You don't have any bookmarks on Twitter")
|
||||
print("- The MCP server couldn't access your bookmarks")
|
||||
print("- Authentication issues with Twitter API")
|
||||
print("- The username filter didn't match any bookmarks")
|
||||
return []
|
||||
|
||||
print(f"✅ Successfully loaded {len(texts)} bookmarks from Twitter")
|
||||
|
||||
# Show sample of what was loaded
|
||||
if texts:
|
||||
sample_text = texts[0][:300] + "..." if len(texts[0]) > 300 else texts[0]
|
||||
print("\nSample bookmark:")
|
||||
print("-" * 50)
|
||||
print(sample_text)
|
||||
print("-" * 50)
|
||||
|
||||
# Convert strings to dict format expected by base class
|
||||
return [{"text": text, "metadata": {"source": "twitter"}} for text in texts]
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading Twitter bookmarks: {e}")
|
||||
print("\nThis might be due to:")
|
||||
print("- MCP server connection issues")
|
||||
print("- Twitter API authentication problems")
|
||||
print("- Network connectivity issues")
|
||||
print("- Rate limiting from Twitter API")
|
||||
raise
|
||||
|
||||
async def run(self):
|
||||
"""Main entry point with MCP connection testing."""
|
||||
args = self.parser.parse_args()
|
||||
|
||||
# Test connection if requested
|
||||
if args.test_connection:
|
||||
success = await self.test_mcp_connection(args)
|
||||
if not success:
|
||||
return
|
||||
print(
|
||||
"\n🎉 MCP server is working! You can now run without --test-connection to start indexing."
|
||||
)
|
||||
return
|
||||
|
||||
# Run the standard RAG pipeline
|
||||
await super().run()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point for the Twitter MCP RAG application."""
|
||||
app = TwitterMCPRAG()
|
||||
await app.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
190
apps/wechat_rag.py
Normal file
190
apps/wechat_rag.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
WeChat History RAG example using the unified interface.
|
||||
Supports WeChat chat history export and search.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from base_rag_example import BaseRAGExample
|
||||
|
||||
from .history_data.wechat_history import WeChatHistoryReader
|
||||
|
||||
|
||||
class WeChatRAG(BaseRAGExample):
|
||||
"""RAG example for WeChat chat history."""
|
||||
|
||||
def __init__(self):
|
||||
# Set default values BEFORE calling super().__init__
|
||||
self.max_items_default = -1 # Match original default
|
||||
self.embedding_model_default = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2" # Fast 384-dim model
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
name="WeChat History",
|
||||
description="Process and query WeChat chat history with LEANN",
|
||||
default_index_name="wechat_history_magic_test_11Debug_new",
|
||||
)
|
||||
|
||||
def _add_specific_arguments(self, parser):
|
||||
"""Add WeChat-specific arguments."""
|
||||
wechat_group = parser.add_argument_group("WeChat Parameters")
|
||||
wechat_group.add_argument(
|
||||
"--export-dir",
|
||||
type=str,
|
||||
default="./wechat_export",
|
||||
help="Directory to store WeChat exports (default: ./wechat_export)",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--force-export",
|
||||
action="store_true",
|
||||
help="Force re-export of WeChat data even if exports exist",
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-size", type=int, default=192, help="Text chunk size (default: 192)"
|
||||
)
|
||||
wechat_group.add_argument(
|
||||
"--chunk-overlap", type=int, default=64, help="Text chunk overlap (default: 64)"
|
||||
)
|
||||
|
||||
def _export_wechat_data(self, export_dir: Path) -> bool:
|
||||
"""Export WeChat data using wechattweak-cli."""
|
||||
print("Exporting WeChat data...")
|
||||
|
||||
# Check if WeChat is running
|
||||
try:
|
||||
result = subprocess.run(["pgrep", "WeChat"], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("WeChat is not running. Please start WeChat first.")
|
||||
return False
|
||||
except Exception:
|
||||
pass # pgrep might not be available on all systems
|
||||
|
||||
# Create export directory
|
||||
export_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Run export command
|
||||
cmd = ["packages/wechat-exporter/wechattweak-cli", "export", str(export_dir)]
|
||||
|
||||
try:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("WeChat data exported successfully!")
|
||||
return True
|
||||
else:
|
||||
print(f"Export failed: {result.stderr}")
|
||||
return False
|
||||
|
||||
except FileNotFoundError:
|
||||
print("\nError: wechattweak-cli not found!")
|
||||
print("Please install it first:")
|
||||
print(" sudo packages/wechat-exporter/wechattweak-cli install")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Export error: {e}")
|
||||
return False
|
||||
|
||||
async def load_data(self, args) -> list[dict[str, Any]]:
|
||||
"""Load WeChat history and convert to text chunks."""
|
||||
# Initialize WeChat reader with export capabilities
|
||||
reader = WeChatHistoryReader()
|
||||
|
||||
# Find existing exports or create new ones using the centralized method
|
||||
export_dirs = reader.find_or_export_wechat_data(args.export_dir)
|
||||
if not export_dirs:
|
||||
print("Failed to find or export WeChat data. Trying to find any existing exports...")
|
||||
# Try to find any existing exports in common locations
|
||||
export_dirs = reader.find_wechat_export_dirs()
|
||||
if not export_dirs:
|
||||
print("No WeChat data found. Please ensure WeChat exports exist.")
|
||||
return []
|
||||
|
||||
# Load documents from all found export directories
|
||||
all_documents = []
|
||||
total_processed = 0
|
||||
|
||||
for i, export_dir in enumerate(export_dirs):
|
||||
print(f"\nProcessing WeChat export {i + 1}/{len(export_dirs)}: {export_dir}")
|
||||
|
||||
try:
|
||||
# Apply max_items limit per export
|
||||
max_per_export = -1
|
||||
if args.max_items > 0:
|
||||
remaining = args.max_items - total_processed
|
||||
if remaining <= 0:
|
||||
break
|
||||
max_per_export = remaining
|
||||
|
||||
documents = reader.load_data(
|
||||
wechat_export_dir=str(export_dir),
|
||||
max_count=max_per_export,
|
||||
concatenate_messages=True, # Enable message concatenation for better context
|
||||
)
|
||||
|
||||
if documents:
|
||||
print(f"Loaded {len(documents)} chat documents from {export_dir}")
|
||||
all_documents.extend(documents)
|
||||
total_processed += len(documents)
|
||||
else:
|
||||
print(f"No documents loaded from {export_dir}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {export_dir}: {e}")
|
||||
continue
|
||||
|
||||
if not all_documents:
|
||||
print("No documents loaded from any source. Exiting.")
|
||||
return []
|
||||
|
||||
print(f"\nTotal loaded {len(all_documents)} chat documents from {len(export_dirs)} exports")
|
||||
print("now starting to split into text chunks ... take some time")
|
||||
|
||||
# Convert to text chunks with contact information
|
||||
all_texts = []
|
||||
for doc in all_documents:
|
||||
# Split the document into chunks
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
text_splitter = SentenceSplitter(
|
||||
chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
|
||||
)
|
||||
nodes = text_splitter.get_nodes_from_documents([doc])
|
||||
|
||||
for node in nodes:
|
||||
# Add contact information to each chunk
|
||||
contact_name = doc.metadata.get("contact_name", "Unknown")
|
||||
text = f"[Contact] means the message is from: {contact_name}\n" + node.get_content()
|
||||
all_texts.append(text)
|
||||
|
||||
print(f"Created {len(all_texts)} text chunks from {len(all_documents)} documents")
|
||||
return all_texts
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# Check platform
|
||||
if sys.platform != "darwin":
|
||||
print("\n⚠️ Warning: WeChat export is only supported on macOS")
|
||||
print(" You can still query existing exports on other platforms\n")
|
||||
|
||||
# Example queries for WeChat RAG
|
||||
print("\n💬 WeChat History RAG Example")
|
||||
print("=" * 50)
|
||||
print("\nExample queries you can try:")
|
||||
print("- 'Show me conversations about travel plans'")
|
||||
print("- 'Find group chats about weekend activities'")
|
||||
print("- '我想买魔术师约翰逊的球衣,给我一些对应聊天记录?'")
|
||||
print("- 'What did we discuss about the project last month?'")
|
||||
print("\nNote: WeChat must be running for export to work\n")
|
||||
|
||||
rag = WeChatRAG()
|
||||
asyncio.run(rag.run())
|
||||
BIN
assets/claude_code_leann.png
Normal file
BIN
assets/claude_code_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 73 KiB |
BIN
assets/mcp_leann.png
Normal file
BIN
assets/mcp_leann.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 224 KiB |
BIN
assets/wechat_user_group.JPG
Normal file
BIN
assets/wechat_user_group.JPG
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 152 KiB |
@@ -1,13 +1,28 @@
|
||||
# 🧪 Leann Sanity Checks
|
||||
# 🧪 LEANN Benchmarks & Testing
|
||||
|
||||
This directory contains comprehensive sanity checks for the Leann system, ensuring all components work correctly across different configurations.
|
||||
This directory contains performance benchmarks and comprehensive tests for the LEANN system, including backend comparisons and sanity checks across different configurations.
|
||||
|
||||
## 📁 Test Files
|
||||
|
||||
### `diskann_vs_hnsw_speed_comparison.py`
|
||||
Performance comparison between DiskANN and HNSW backends:
|
||||
- ✅ **Search latency** comparison with both backends using recompute
|
||||
- ✅ **Index size** and **build time** measurements
|
||||
- ✅ **Score validity** testing (ensures no -inf scores)
|
||||
- ✅ **Configurable dataset sizes** for different scales
|
||||
|
||||
```bash
|
||||
# Quick comparison with 500 docs, 10 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py
|
||||
|
||||
# Large-scale comparison with 2000 docs, 20 queries
|
||||
python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20
|
||||
```
|
||||
|
||||
### `test_distance_functions.py`
|
||||
Tests all supported distance functions across DiskANN backend:
|
||||
- ✅ **MIPS** (Maximum Inner Product Search)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **L2** (Euclidean Distance)
|
||||
- ✅ **Cosine** (Cosine Similarity)
|
||||
|
||||
```bash
|
||||
@@ -27,7 +42,7 @@ uv run python tests/sanity_checks/test_l2_verification.py
|
||||
### `test_sanity_check.py`
|
||||
Comprehensive end-to-end verification including:
|
||||
- Distance function testing
|
||||
- Embedding model compatibility
|
||||
- Embedding model compatibility
|
||||
- Search result correctness validation
|
||||
- Backend integration testing
|
||||
|
||||
@@ -64,7 +79,7 @@ When all tests pass, you should see:
|
||||
```
|
||||
📊 测试结果总结:
|
||||
mips : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
l2 : ✅ 通过
|
||||
cosine : ✅ 通过
|
||||
|
||||
🎉 测试完成!
|
||||
@@ -98,7 +113,7 @@ pkill -f "embedding_server"
|
||||
|
||||
### Typical Timing (3 documents, consumer hardware):
|
||||
- **Index Building**: 2-5 seconds per distance function
|
||||
- **Search Query**: 50-200ms
|
||||
- **Search Query**: 50-200ms
|
||||
- **Recompute Mode**: 5-15 seconds (higher accuracy)
|
||||
|
||||
### Memory Usage:
|
||||
@@ -117,4 +132,4 @@ These tests are designed to be run in automated environments:
|
||||
uv run python tests/sanity_checks/test_l2_verification.py
|
||||
```
|
||||
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
The tests are deterministic and should produce consistent results across different platforms.
|
||||
0
benchmarks/__init__.py
Normal file
0
benchmarks/__init__.py
Normal file
@@ -1,43 +1,46 @@
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
import torch
|
||||
from mlx_lm import load
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_NAME_TORCH = "Qwen/Qwen3-Embedding-0.6B"
|
||||
MODEL_NAME_MLX = "mlx-community/Qwen3-Embedding-0.6B-4bit-DWQ"
|
||||
BATCH_SIZES = [1, 8, 16, 32, 64, 128]
|
||||
NUM_RUNS = 10 # Number of runs to average for each batch size
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
WARMUP_RUNS = 2 # Number of warm-up runs
|
||||
|
||||
# --- Generate Dummy Data ---
|
||||
DUMMY_SENTENCES = ["This is a test sentence for benchmarking." * 5] * max(BATCH_SIZES)
|
||||
|
||||
# --- Benchmark Functions ---b
|
||||
|
||||
|
||||
def benchmark_torch(model, sentences):
|
||||
start_time = time.time()
|
||||
model.encode(sentences, convert_to_numpy=True)
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
def benchmark_mlx(model, tokenizer, sentences):
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Tokenize sentences using MLX tokenizer
|
||||
tokens = []
|
||||
for sentence in sentences:
|
||||
token_ids = tokenizer.encode(sentence)
|
||||
tokens.append(token_ids)
|
||||
|
||||
|
||||
# Pad sequences to the same length
|
||||
max_len = max(len(t) for t in tokens)
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
|
||||
|
||||
for token_seq in tokens:
|
||||
# Pad sequence
|
||||
padded = token_seq + [tokenizer.eos_token_id] * (max_len - len(token_seq))
|
||||
@@ -45,24 +48,25 @@ def benchmark_mlx(model, tokenizer, sentences):
|
||||
# Create attention mask (1 for real tokens, 0 for padding)
|
||||
mask = [1] * len(token_seq) + [0] * (max_len - len(token_seq))
|
||||
attention_mask.append(mask)
|
||||
|
||||
|
||||
# Convert to MLX arrays
|
||||
input_ids = mx.array(input_ids)
|
||||
attention_mask = mx.array(attention_mask)
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = model(input_ids)
|
||||
|
||||
|
||||
# Mean pooling
|
||||
mask = mx.expand_dims(attention_mask, -1)
|
||||
sum_embeddings = (embeddings * mask).sum(axis=1)
|
||||
sum_mask = mask.sum(axis=1)
|
||||
_ = sum_embeddings / sum_mask
|
||||
|
||||
|
||||
mx.eval() # Ensure computation is finished
|
||||
end_time = time.time()
|
||||
return (end_time - start_time) * 1000 # Return time in ms
|
||||
|
||||
|
||||
# --- Main Execution ---
|
||||
def main():
|
||||
print("--- Initializing Models ---")
|
||||
@@ -92,13 +96,15 @@ def main():
|
||||
for batch_size in BATCH_SIZES:
|
||||
print(f"Benchmarking batch size: {batch_size}")
|
||||
sentences_batch = DUMMY_SENTENCES[:batch_size]
|
||||
|
||||
|
||||
# Benchmark PyTorch
|
||||
torch_times = [benchmark_torch(model_torch, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
results_torch.append(np.mean(torch_times))
|
||||
|
||||
|
||||
# Benchmark MLX
|
||||
mlx_times = [benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)]
|
||||
mlx_times = [
|
||||
benchmark_mlx(model_mlx, tokenizer_mlx, sentences_batch) for _ in range(NUM_RUNS)
|
||||
]
|
||||
results_mlx.append(np.mean(mlx_times))
|
||||
|
||||
print("\n--- Benchmark Results (Average time per batch in ms) ---")
|
||||
@@ -109,20 +115,27 @@ def main():
|
||||
# --- Plotting ---
|
||||
print("\n--- Generating Plot ---")
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(BATCH_SIZES, results_torch, marker='o', linestyle='-', label=f'PyTorch ({device})')
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker='s', linestyle='-', label='MLX')
|
||||
plt.plot(
|
||||
BATCH_SIZES,
|
||||
results_torch,
|
||||
marker="o",
|
||||
linestyle="-",
|
||||
label=f"PyTorch ({device})",
|
||||
)
|
||||
plt.plot(BATCH_SIZES, results_mlx, marker="s", linestyle="-", label="MLX")
|
||||
|
||||
plt.title(f'Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}')
|
||||
plt.title(f"Embedding Performance: MLX vs PyTorch\nModel: {MODEL_NAME_TORCH}")
|
||||
plt.xlabel("Batch Size")
|
||||
plt.ylabel("Average Time per Batch (ms)")
|
||||
plt.xticks(BATCH_SIZES)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
|
||||
# Save the plot
|
||||
output_filename = "embedding_benchmark.png"
|
||||
plt.savefig(output_filename)
|
||||
print(f"Plot saved to {output_filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
148
benchmarks/benchmark_no_recompute.py
Normal file
148
benchmarks/benchmark_no_recompute.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
def _meta_exists(index_path: str) -> bool:
|
||||
p = Path(index_path)
|
||||
return (p.parent / f"{p.stem}.meta.json").exists()
|
||||
|
||||
|
||||
def ensure_index(index_path: str, backend_name: str, num_docs: int, is_recompute: bool) -> None:
|
||||
# if _meta_exists(index_path):
|
||||
# return
|
||||
kwargs = {}
|
||||
if backend_name == "hnsw":
|
||||
kwargs["is_compact"] = is_recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model=os.getenv("LEANN_EMBED_MODEL", "facebook/contriever"),
|
||||
embedding_mode=os.getenv("LEANN_EMBED_MODE", "sentence-transformers"),
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=is_recompute,
|
||||
num_threads=4,
|
||||
**kwargs,
|
||||
)
|
||||
for i in range(num_docs):
|
||||
builder.add_text(
|
||||
f"This is a test document number {i}. It contains some repeated text for benchmarking."
|
||||
)
|
||||
builder.build_index(index_path)
|
||||
|
||||
|
||||
def _bench_group(
|
||||
index_path: str,
|
||||
recompute: bool,
|
||||
query: str,
|
||||
repeats: int,
|
||||
complexity: int = 32,
|
||||
top_k: int = 10,
|
||||
) -> float:
|
||||
# Independent searcher per group; fixed port when recompute
|
||||
searcher = LeannSearcher(index_path=index_path)
|
||||
|
||||
# Warm-up once
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
|
||||
def _once() -> float:
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
query,
|
||||
top_k=top_k,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute,
|
||||
)
|
||||
return time.time() - t0
|
||||
|
||||
if repeats <= 1:
|
||||
t = _once()
|
||||
else:
|
||||
vals = [_once() for _ in range(repeats)]
|
||||
vals.sort()
|
||||
t = vals[len(vals) // 2]
|
||||
|
||||
searcher.cleanup()
|
||||
return t
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-docs", type=int, default=5000)
|
||||
parser.add_argument("--repeats", type=int, default=3)
|
||||
parser.add_argument("--complexity", type=int, default=32)
|
||||
args = parser.parse_args()
|
||||
|
||||
base = Path.cwd() / ".leann" / "indexes" / f"bench_n{args.num_docs}"
|
||||
base.parent.mkdir(parents=True, exist_ok=True)
|
||||
# ---------- Build HNSW variants ----------
|
||||
hnsw_r = str(base / f"hnsw_recompute_n{args.num_docs}.leann")
|
||||
hnsw_nr = str(base / f"hnsw_norecompute_n{args.num_docs}.leann")
|
||||
ensure_index(hnsw_r, "hnsw", args.num_docs, True)
|
||||
ensure_index(hnsw_nr, "hnsw", args.num_docs, False)
|
||||
|
||||
# ---------- Build DiskANN variants ----------
|
||||
diskann_r = str(base / "diskann_r.leann")
|
||||
diskann_nr = str(base / "diskann_nr.leann")
|
||||
ensure_index(diskann_r, "diskann", args.num_docs, True)
|
||||
ensure_index(diskann_nr, "diskann", args.num_docs, False)
|
||||
|
||||
# ---------- Helpers ----------
|
||||
def _size_for(prefix: str) -> int:
|
||||
p = Path(prefix)
|
||||
base_dir = p.parent
|
||||
stem = p.stem
|
||||
total = 0
|
||||
for f in base_dir.iterdir():
|
||||
if f.is_file() and f.name.startswith(stem):
|
||||
total += f.stat().st_size
|
||||
return total
|
||||
|
||||
# ---------- HNSW benchmark ----------
|
||||
t_hnsw_r = _bench_group(
|
||||
hnsw_r, True, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_hnsw_nr = _bench_group(
|
||||
hnsw_nr, False, "test document number 42", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
size_hnsw_r = _size_for(hnsw_r)
|
||||
size_hnsw_nr = _size_for(hnsw_nr)
|
||||
|
||||
print("Benchmark results (HNSW):")
|
||||
print(f" recompute=True: search_time={t_hnsw_r:.3f}s, size={size_hnsw_r / 1024 / 1024:.1f}MB")
|
||||
print(
|
||||
f" recompute=False: search_time={t_hnsw_nr:.3f}s, size={size_hnsw_nr / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
print(" Expectation: no-recompute should be faster but larger on disk.")
|
||||
|
||||
# ---------- DiskANN benchmark ----------
|
||||
t_diskann_r = _bench_group(
|
||||
diskann_r, True, "DiskANN R test doc 123", repeats=args.repeats, complexity=args.complexity
|
||||
)
|
||||
t_diskann_nr = _bench_group(
|
||||
diskann_nr,
|
||||
False,
|
||||
"DiskANN NR test doc 123",
|
||||
repeats=args.repeats,
|
||||
complexity=args.complexity,
|
||||
)
|
||||
size_diskann_r = _size_for(diskann_r)
|
||||
size_diskann_nr = _size_for(diskann_nr)
|
||||
|
||||
print("\nBenchmark results (DiskANN):")
|
||||
print(f" build(recompute=True, partition): size={size_diskann_r / 1024 / 1024:.1f}MB")
|
||||
print(f" build(recompute=False): size={size_diskann_nr / 1024 / 1024:.1f}MB")
|
||||
print(f" search recompute=True (final rerank): {t_diskann_r:.3f}s")
|
||||
print(f" search recompute=False (PQ only): {t_diskann_nr:.3f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
23
benchmarks/bm25_diskann_baselines/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
BM25 vs DiskANN Baselines
|
||||
|
||||
```bash
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/bm25_rpj_wiki/index_en_only/ benchmarks/data/indices/bm25_index/
|
||||
aws s3 sync s3://powerrag-diskann-rpj-wiki-20250824-224037-194d640c/diskann_rpj_wiki/ benchmarks/data/indices/diskann_rpj_wiki/
|
||||
```
|
||||
|
||||
- Dataset: `benchmarks/data/queries/nq_open.jsonl` (Natural Questions)
|
||||
- Machine-specific; results measured locally with the current repo.
|
||||
|
||||
DiskANN (NQ queries, search-only)
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_diskann.py`
|
||||
- Settings: `recompute_embeddings=False`, embeddings precomputed (excluded from timing), batching off, caching off (`cache_mechanism=2`, `num_nodes_to_cache=0`)
|
||||
- Result: avg 0.011093 s/query, QPS 90.15 (p50 0.010731 s, p95 0.015000 s)
|
||||
|
||||
BM25
|
||||
- Command: `uv run --script benchmarks/bm25_diskann_baselines/run_bm25.py`
|
||||
- Settings: `k=10`, `k1=0.9`, `b=0.4`, queries=100
|
||||
- Result: avg 0.028589 s/query, QPS 34.97 (p50 0.026060 s, p90 0.043695 s, p95 0.053260 s, p99 0.055257 s)
|
||||
|
||||
Notes
|
||||
- DiskANN measures search-only latency on real NQ queries (embeddings computed beforehand and excluded from timing).
|
||||
- Use `benchmarks/bm25_diskann_baselines/run_diskann.py` for DiskANN; `benchmarks/bm25_diskann_baselines/run_bm25.py` for BM25.
|
||||
|
After Width: | Height: | Size: 1.3 KiB |
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
183
benchmarks/bm25_diskann_baselines/run_bm25.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "pyserini"
|
||||
# ]
|
||||
# ///
|
||||
# sudo pacman -S jdk21-openjdk
|
||||
# export JAVA_HOME=/usr/lib/jvm/java-21-openjdk
|
||||
# sudo archlinux-java status
|
||||
# sudo archlinux-java set java-21-openjdk
|
||||
# set -Ux JAVA_HOME /usr/lib/jvm/java-21-openjdk
|
||||
# fish_add_path --global $JAVA_HOME/bin
|
||||
# set -Ux LD_LIBRARY_PATH $JAVA_HOME/lib/server $LD_LIBRARY_PATH
|
||||
# which javac # Should be /usr/lib/jvm/java-21-openjdk/bin/javac
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from statistics import mean
|
||||
|
||||
|
||||
def load_queries(path: str, limit: int | None) -> list[str]:
|
||||
queries: list[str] = []
|
||||
# Try JSONL with a 'query' or 'text' field; fallback to plain text (one query per line)
|
||||
_, ext = os.path.splitext(path)
|
||||
if ext.lower() in {".jsonl", ".json"}:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
# Not strict JSONL? treat the whole line as the query
|
||||
queries.append(line)
|
||||
continue
|
||||
q = obj.get("query") or obj.get("text") or obj.get("question")
|
||||
if q:
|
||||
queries.append(str(q))
|
||||
else:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
s = line.strip()
|
||||
if s:
|
||||
queries.append(s)
|
||||
|
||||
if limit is not None and limit > 0:
|
||||
queries = queries[:limit]
|
||||
return queries
|
||||
|
||||
|
||||
def percentile(values: list[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
s = sorted(values)
|
||||
k = (len(s) - 1) * (p / 100.0)
|
||||
f = int(k)
|
||||
c = min(f + 1, len(s) - 1)
|
||||
if f == c:
|
||||
return s[f]
|
||||
return s[f] + (s[c] - s[f]) * (k - f)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser(description="Standalone BM25 latency benchmark (Pyserini)")
|
||||
ap.add_argument(
|
||||
"--bm25-index",
|
||||
default="benchmarks/data/indices/bm25_index",
|
||||
help="Path to Pyserini Lucene index directory",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--queries",
|
||||
default="benchmarks/data/queries/nq_open.jsonl",
|
||||
help="Path to queries file (JSONL with 'query'/'text' or plain txt one-per-line)",
|
||||
)
|
||||
ap.add_argument("--k", type=int, default=10, help="Top-k to retrieve (default: 10)")
|
||||
ap.add_argument("--k1", type=float, default=0.9, help="BM25 k1 (default: 0.9)")
|
||||
ap.add_argument("--b", type=float, default=0.4, help="BM25 b (default: 0.4)")
|
||||
ap.add_argument("--limit", type=int, default=100, help="Max queries to run (default: 100)")
|
||||
ap.add_argument(
|
||||
"--warmup", type=int, default=5, help="Warmup queries not counted in latency (default: 5)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--fetch-docs", action="store_true", help="Also fetch doc contents (slower; default: off)"
|
||||
)
|
||||
ap.add_argument("--report", type=str, default=None, help="Optional JSON report path")
|
||||
args = ap.parse_args()
|
||||
|
||||
try:
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
except Exception:
|
||||
print("Pyserini not found. Install with: pip install pyserini", file=sys.stderr)
|
||||
raise
|
||||
|
||||
if not os.path.isdir(args.bm25_index):
|
||||
print(f"Index directory not found: {args.bm25_index}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
queries = load_queries(args.queries, args.limit)
|
||||
if not queries:
|
||||
print("No queries loaded.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Loaded {len(queries)} queries from {args.queries}")
|
||||
print(f"Opening BM25 index: {args.bm25_index}")
|
||||
searcher = LuceneSearcher(args.bm25_index)
|
||||
# Some builds of pyserini require explicit set_bm25; others ignore
|
||||
try:
|
||||
searcher.set_bm25(k1=args.k1, b=args.b)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
latencies: list[float] = []
|
||||
total_searches = 0
|
||||
|
||||
# Warmup
|
||||
for i in range(min(args.warmup, len(queries))):
|
||||
_ = searcher.search(queries[i], k=args.k)
|
||||
|
||||
t0 = time.time()
|
||||
for i, q in enumerate(queries):
|
||||
t1 = time.time()
|
||||
hits = searcher.search(q, k=args.k)
|
||||
t2 = time.time()
|
||||
latencies.append(t2 - t1)
|
||||
total_searches += 1
|
||||
|
||||
if args.fetch_docs:
|
||||
# Optional doc fetch to include I/O time
|
||||
for h in hits:
|
||||
try:
|
||||
_ = searcher.doc(h.docid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (i + 1) % 50 == 0:
|
||||
print(f"Processed {i + 1}/{len(queries)} queries")
|
||||
|
||||
t1 = time.time()
|
||||
total_time = t1 - t0
|
||||
|
||||
if latencies:
|
||||
avg = mean(latencies)
|
||||
p50 = percentile(latencies, 50)
|
||||
p90 = percentile(latencies, 90)
|
||||
p95 = percentile(latencies, 95)
|
||||
p99 = percentile(latencies, 99)
|
||||
qps = total_searches / total_time if total_time > 0 else 0.0
|
||||
else:
|
||||
avg = p50 = p90 = p95 = p99 = qps = 0.0
|
||||
|
||||
print("BM25 Latency Report")
|
||||
print(f" queries: {total_searches}")
|
||||
print(f" k: {args.k}, k1: {args.k1}, b: {args.b}")
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p90/p95/p99: {p50:.6f}/{p90:.6f}/{p95:.6f}/{p99:.6f} s")
|
||||
print(f" total time: {total_time:.3f} s, qps: {qps:.2f}")
|
||||
|
||||
if args.report:
|
||||
payload = {
|
||||
"queries": total_searches,
|
||||
"k": args.k,
|
||||
"k1": args.k1,
|
||||
"b": args.b,
|
||||
"avg_s": avg,
|
||||
"p50_s": p50,
|
||||
"p90_s": p90,
|
||||
"p95_s": p95,
|
||||
"p99_s": p99,
|
||||
"total_time_s": total_time,
|
||||
"qps": qps,
|
||||
"index_dir": os.path.abspath(args.bm25_index),
|
||||
"fetch_docs": bool(args.fetch_docs),
|
||||
}
|
||||
with open(args.report, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, indent=2)
|
||||
print(f"Saved report to {args.report}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
124
benchmarks/bm25_diskann_baselines/run_diskann.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "leann-backend-diskann"
|
||||
# ]
|
||||
# ///
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_queries(path: Path, limit: int | None) -> list[str]:
|
||||
out: list[str] = []
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
obj = json.loads(line)
|
||||
out.append(obj["query"])
|
||||
if limit and len(out) >= limit:
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ap = argparse.ArgumentParser(
|
||||
description="DiskANN baseline on real NQ queries (search-only timing)"
|
||||
)
|
||||
ap.add_argument(
|
||||
"--index-dir",
|
||||
default="benchmarks/data/indices/diskann_rpj_wiki",
|
||||
help="Directory containing DiskANN files",
|
||||
)
|
||||
ap.add_argument("--index-prefix", default="ann")
|
||||
ap.add_argument("--queries-file", default="benchmarks/data/queries/nq_open.jsonl")
|
||||
ap.add_argument("--num-queries", type=int, default=200)
|
||||
ap.add_argument("--top-k", type=int, default=10)
|
||||
ap.add_argument("--complexity", type=int, default=62)
|
||||
ap.add_argument("--threads", type=int, default=1)
|
||||
ap.add_argument("--beam-width", type=int, default=1)
|
||||
ap.add_argument("--cache-mechanism", type=int, default=2)
|
||||
ap.add_argument("--num-nodes-to-cache", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
index_dir = Path(args.index_dir).resolve()
|
||||
if not index_dir.is_dir():
|
||||
raise SystemExit(f"Index dir not found: {index_dir}")
|
||||
|
||||
qpath = Path(args.queries_file).resolve()
|
||||
if not qpath.exists():
|
||||
raise SystemExit(f"Queries file not found: {qpath}")
|
||||
|
||||
queries = load_queries(qpath, args.num_queries)
|
||||
print(f"Loaded {len(queries)} queries from {qpath}")
|
||||
|
||||
# Compute embeddings once (exclude from timing)
|
||||
from leann.api import compute_embeddings as _compute
|
||||
|
||||
embs = _compute(
|
||||
queries,
|
||||
model_name="facebook/contriever-msmarco",
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
if embs.ndim != 2:
|
||||
raise SystemExit("Embedding compute failed or returned wrong shape")
|
||||
|
||||
# Build searcher
|
||||
from leann_backend_diskann.diskann_backend import DiskannSearcher as _DiskannSearcher
|
||||
|
||||
index_prefix_path = str(index_dir / args.index_prefix)
|
||||
searcher = _DiskannSearcher(
|
||||
index_prefix_path,
|
||||
num_threads=int(args.threads),
|
||||
cache_mechanism=int(args.cache_mechanism),
|
||||
num_nodes_to_cache=int(args.num_nodes_to_cache),
|
||||
)
|
||||
|
||||
# Warmup (not timed)
|
||||
_ = searcher.search(
|
||||
embs[0:1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
|
||||
# Timed loop
|
||||
times: list[float] = []
|
||||
for i in range(embs.shape[0]):
|
||||
t0 = time.time()
|
||||
_ = searcher.search(
|
||||
embs[i : i + 1],
|
||||
top_k=args.top_k,
|
||||
complexity=args.complexity,
|
||||
beam_width=args.beam_width,
|
||||
prune_ratio=0.0,
|
||||
recompute_embeddings=False,
|
||||
batch_recompute=False,
|
||||
dedup_node_dis=False,
|
||||
)
|
||||
times.append(time.time() - t0)
|
||||
|
||||
times_sorted = sorted(times)
|
||||
avg = float(sum(times) / len(times))
|
||||
p50 = times_sorted[len(times) // 2]
|
||||
p95 = times_sorted[max(0, int(len(times) * 0.95) - 1)]
|
||||
|
||||
print("\nDiskANN (NQ, search-only) Report")
|
||||
print(f" queries: {len(times)}")
|
||||
print(
|
||||
f" k: {args.top_k}, complexity: {args.complexity}, beam_width: {args.beam_width}, threads: {args.threads}"
|
||||
)
|
||||
print(f" avg per query: {avg:.6f} s")
|
||||
print(f" p50/p95: {p50:.6f}/{p95:.6f} s")
|
||||
print(f" QPS: {1.0 / avg:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,14 +3,15 @@
|
||||
Memory comparison between Faiss HNSW and LEANN HNSW backend
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import psutil
|
||||
import gc
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
# Setup logging
|
||||
@@ -61,7 +62,7 @@ def test_faiss_hnsw():
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "examples/faiss_only.py"],
|
||||
[sys.executable, "benchmarks/faiss_only.py"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
@@ -83,9 +84,7 @@ def test_faiss_hnsw():
|
||||
|
||||
for line in lines:
|
||||
if "Peak Memory:" in line:
|
||||
peak_memory = float(
|
||||
line.split("Peak Memory:")[1].split("MB")[0].strip()
|
||||
)
|
||||
peak_memory = float(line.split("Peak Memory:")[1].split("MB")[0].strip())
|
||||
|
||||
return {"peak_memory": peak_memory}
|
||||
|
||||
@@ -111,13 +110,12 @@ def test_leann_hnsw():
|
||||
|
||||
tracker.checkpoint("After imports")
|
||||
|
||||
from leann.api import LeannBuilder
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
|
||||
# Load and parse documents
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -197,16 +195,14 @@ def test_leann_hnsw():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
# Load searcher
|
||||
searcher = LeannSearcher(index_path)
|
||||
tracker.checkpoint("After searcher loading")
|
||||
|
||||
|
||||
|
||||
print("Running search queries...")
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -304,21 +300,15 @@ def main():
|
||||
|
||||
print("\nLEANN vs Faiss Performance:")
|
||||
memory_saving = faiss_results["peak_memory"] - leann_results["peak_memory"]
|
||||
print(
|
||||
f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)"
|
||||
)
|
||||
print(f" Search Memory: {memory_ratio:.1f}x less ({memory_saving:.1f} MB saved)")
|
||||
|
||||
# Storage comparison
|
||||
if leann_storage_size > faiss_storage_size:
|
||||
storage_ratio = leann_storage_size / faiss_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x larger (LEANN uses more storage)")
|
||||
elif faiss_storage_size > leann_storage_size:
|
||||
storage_ratio = faiss_storage_size / leann_storage_size
|
||||
print(
|
||||
f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)"
|
||||
)
|
||||
print(f" Storage Size: {storage_ratio:.1f}x smaller (LEANN uses less storage)")
|
||||
else:
|
||||
print(" Storage Size: similar")
|
||||
else:
|
||||
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
0
data/README.md → benchmarks/data/README.md
Normal file → Executable file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
286
benchmarks/diskann_vs_hnsw_speed_comparison.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DiskANN vs HNSW Search Performance Comparison
|
||||
|
||||
This benchmark compares search performance between DiskANN and HNSW backends:
|
||||
- DiskANN: With graph partitioning enabled (is_recompute=True)
|
||||
- HNSW: With recompute enabled (is_recompute=True)
|
||||
- Tests performance across different dataset sizes
|
||||
- Measures search latency, recall, and index size
|
||||
"""
|
||||
|
||||
import gc
|
||||
import multiprocessing as mp
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Prefer 'fork' start method to avoid POSIX semaphore leaks on macOS
|
||||
try:
|
||||
mp.set_start_method("fork", force=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def create_test_texts(n_docs: int) -> list[str]:
|
||||
"""Create synthetic test documents for benchmarking."""
|
||||
np.random.seed(42)
|
||||
topics = [
|
||||
"machine learning and artificial intelligence",
|
||||
"natural language processing and text analysis",
|
||||
"computer vision and image recognition",
|
||||
"data science and statistical analysis",
|
||||
"deep learning and neural networks",
|
||||
"information retrieval and search engines",
|
||||
"database systems and data management",
|
||||
"software engineering and programming",
|
||||
"cybersecurity and network protection",
|
||||
"cloud computing and distributed systems",
|
||||
]
|
||||
|
||||
texts = []
|
||||
for i in range(n_docs):
|
||||
topic = topics[i % len(topics)]
|
||||
variation = np.random.randint(1, 100)
|
||||
text = (
|
||||
f"This is document {i} about {topic}. Content variation {variation}. "
|
||||
f"Additional information about {topic} with details and examples. "
|
||||
f"Technical discussion of {topic} including implementation aspects."
|
||||
)
|
||||
texts.append(text)
|
||||
|
||||
return texts
|
||||
|
||||
|
||||
def benchmark_backend(
|
||||
backend_name: str, texts: list[str], test_queries: list[str], backend_kwargs: dict[str, Any]
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark a specific backend with the given configuration."""
|
||||
from leann.api import LeannBuilder, LeannSearcher
|
||||
|
||||
print(f"\n🔧 Testing {backend_name.upper()} backend...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
index_path = str(Path(temp_dir) / f"benchmark_{backend_name}.leann")
|
||||
|
||||
# Build index
|
||||
print(f"📦 Building {backend_name} index with {len(texts)} documents...")
|
||||
start_time = time.time()
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend_name,
|
||||
embedding_model="facebook/contriever",
|
||||
embedding_mode="sentence-transformers",
|
||||
**backend_kwargs,
|
||||
)
|
||||
|
||||
for text in texts:
|
||||
builder.add_text(text)
|
||||
|
||||
builder.build_index(index_path)
|
||||
build_time = time.time() - start_time
|
||||
|
||||
# Measure index size
|
||||
index_dir = Path(index_path).parent
|
||||
index_files = list(index_dir.glob(f"{Path(index_path).stem}.*"))
|
||||
total_size = sum(f.stat().st_size for f in index_files if f.is_file())
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
|
||||
print(f" ✅ Build completed in {build_time:.2f}s, index size: {size_mb:.1f}MB")
|
||||
|
||||
# Search benchmark
|
||||
print("🔍 Running search benchmark...")
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
search_times = []
|
||||
all_results = []
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
results = searcher.search(query, top_k=5)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
all_results.append(results)
|
||||
|
||||
avg_search_time = np.mean(search_times) * 1000 # Convert to ms
|
||||
print(f" ✅ Average search time: {avg_search_time:.1f}ms")
|
||||
|
||||
# Check for valid scores (detect -inf issues)
|
||||
all_scores = [
|
||||
result.score
|
||||
for results in all_results
|
||||
for result in results
|
||||
if result.score is not None
|
||||
]
|
||||
valid_scores = [
|
||||
score for score in all_scores if score != float("-inf") and score != float("inf")
|
||||
]
|
||||
score_validity_rate = len(valid_scores) / len(all_scores) if all_scores else 0
|
||||
|
||||
# Clean up (ensure embedding server shutdown and object GC)
|
||||
try:
|
||||
if hasattr(searcher, "cleanup"):
|
||||
searcher.cleanup()
|
||||
del searcher
|
||||
del builder
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Resource cleanup error: {e}")
|
||||
|
||||
return {
|
||||
"build_time": build_time,
|
||||
"avg_search_time_ms": avg_search_time,
|
||||
"index_size_mb": size_mb,
|
||||
"score_validity_rate": score_validity_rate,
|
||||
}
|
||||
|
||||
|
||||
def run_comparison(n_docs: int = 500, n_queries: int = 10):
|
||||
"""Run performance comparison between DiskANN and HNSW."""
|
||||
print("🚀 Starting DiskANN vs HNSW Performance Comparison")
|
||||
print(f"📊 Dataset: {n_docs} documents, {n_queries} test queries")
|
||||
|
||||
# Create test data
|
||||
texts = create_test_texts(n_docs)
|
||||
test_queries = [
|
||||
"machine learning algorithms",
|
||||
"natural language processing",
|
||||
"computer vision techniques",
|
||||
"data analysis methods",
|
||||
"neural network architectures",
|
||||
"database query optimization",
|
||||
"software development practices",
|
||||
"security vulnerabilities",
|
||||
"cloud infrastructure",
|
||||
"distributed computing",
|
||||
][:n_queries]
|
||||
|
||||
# HNSW benchmark
|
||||
hnsw_results = benchmark_backend(
|
||||
backend_name="hnsw",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable recompute for fair comparison
|
||||
"M": 16,
|
||||
"efConstruction": 200,
|
||||
},
|
||||
)
|
||||
|
||||
# DiskANN benchmark
|
||||
diskann_results = benchmark_backend(
|
||||
backend_name="diskann",
|
||||
texts=texts,
|
||||
test_queries=test_queries,
|
||||
backend_kwargs={
|
||||
"is_recompute": True, # Enable graph partitioning
|
||||
"num_neighbors": 32,
|
||||
"search_list_size": 50,
|
||||
},
|
||||
)
|
||||
|
||||
# Performance comparison
|
||||
print("\n📈 Performance Comparison Results")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"{'Metric':<25} {'HNSW':<15} {'DiskANN':<15} {'Speedup':<10}")
|
||||
print(f"{'-' * 60}")
|
||||
|
||||
# Build time comparison
|
||||
build_speedup = hnsw_results["build_time"] / diskann_results["build_time"]
|
||||
print(
|
||||
f"{'Build Time (s)':<25} {hnsw_results['build_time']:<15.2f} {diskann_results['build_time']:<15.2f} {build_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Search time comparison
|
||||
search_speedup = hnsw_results["avg_search_time_ms"] / diskann_results["avg_search_time_ms"]
|
||||
print(
|
||||
f"{'Search Time (ms)':<25} {hnsw_results['avg_search_time_ms']:<15.1f} {diskann_results['avg_search_time_ms']:<15.1f} {search_speedup:<10.2f}x"
|
||||
)
|
||||
|
||||
# Index size comparison
|
||||
size_ratio = diskann_results["index_size_mb"] / hnsw_results["index_size_mb"]
|
||||
print(
|
||||
f"{'Index Size (MB)':<25} {hnsw_results['index_size_mb']:<15.1f} {diskann_results['index_size_mb']:<15.1f} {size_ratio:<10.2f}x"
|
||||
)
|
||||
|
||||
# Score validity
|
||||
print(
|
||||
f"{'Score Validity (%)':<25} {hnsw_results['score_validity_rate'] * 100:<15.1f} {diskann_results['score_validity_rate'] * 100:<15.1f}"
|
||||
)
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print("\n🎯 Summary:")
|
||||
if search_speedup > 1:
|
||||
print(f" DiskANN is {search_speedup:.2f}x faster than HNSW for search")
|
||||
else:
|
||||
print(f" HNSW is {1 / search_speedup:.2f}x faster than DiskANN for search")
|
||||
|
||||
if size_ratio > 1:
|
||||
print(f" DiskANN uses {size_ratio:.2f}x more storage than HNSW")
|
||||
else:
|
||||
print(f" DiskANN uses {1 / size_ratio:.2f}x less storage than HNSW")
|
||||
|
||||
print(
|
||||
f" Both backends achieved {min(hnsw_results['score_validity_rate'], diskann_results['score_validity_rate']) * 100:.1f}% score validity"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Handle help request
|
||||
if len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help", "help"]:
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Usage: python {sys.argv[0]} [n_docs] [n_queries]")
|
||||
print()
|
||||
print("Arguments:")
|
||||
print(" n_docs Number of documents to index (default: 500)")
|
||||
print(" n_queries Number of test queries to run (default: 10)")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 1000")
|
||||
print(" python benchmarks/diskann_vs_hnsw_speed_comparison.py 2000 20")
|
||||
sys.exit(0)
|
||||
|
||||
# Parse command line arguments
|
||||
n_docs = int(sys.argv[1]) if len(sys.argv) > 1 else 500
|
||||
n_queries = int(sys.argv[2]) if len(sys.argv) > 2 else 10
|
||||
|
||||
print("DiskANN vs HNSW Performance Comparison")
|
||||
print("=" * 50)
|
||||
print(f"Dataset: {n_docs} documents, {n_queries} queries")
|
||||
print()
|
||||
|
||||
run_comparison(n_docs=n_docs, n_queries=n_queries)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Benchmark interrupted by user")
|
||||
sys.exit(130)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Benchmark failed: {e}")
|
||||
sys.exit(1)
|
||||
finally:
|
||||
# Ensure clean exit (forceful to prevent rare hangs from atexit/threads)
|
||||
try:
|
||||
gc.collect()
|
||||
print("\n🧹 Cleanup completed")
|
||||
# Flush stdio to ensure message is visible before hard-exit
|
||||
try:
|
||||
import sys as _sys
|
||||
|
||||
_sys.stdout.flush()
|
||||
_sys.stderr.flush()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# Use os._exit to bypass atexit handlers that may hang in rare cases
|
||||
import os as _os
|
||||
|
||||
_os._exit(0)
|
||||
141
benchmarks/enron_emails/README.md
Normal file
141
benchmarks/enron_emails/README.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Enron Emails Benchmark
|
||||
|
||||
A comprehensive RAG benchmark for evaluating LEANN search and generation on the Enron email corpus. It mirrors the structure and CLI of the existing FinanceBench and LAION benches, using stage-based evaluation with Recall@3 and generation timing.
|
||||
|
||||
- Dataset: Enron email CSV (e.g., Kaggle wcukierski/enron-email-dataset) for passages
|
||||
- Queries: corbt/enron_emails_sample_questions (filtered for realistic questions)
|
||||
- Metrics: Recall@3 vs FAISS Flat baseline + Generation evaluation with Qwen3-8B
|
||||
|
||||
## Layout
|
||||
|
||||
benchmarks/enron_emails/
|
||||
- setup_enron_emails.py: Prepare passages, build LEANN index, build FAISS baseline
|
||||
- evaluate_enron_emails.py: Evaluate retrieval recall (Stages 2-5) + generation with Qwen3-8B
|
||||
- data/: Generated passages, queries, embeddings-related files
|
||||
- baseline/: FAISS Flat baseline files
|
||||
- llm_utils.py: LLM utilities for Qwen3-8B generation (in parent directory)
|
||||
|
||||
## Quickstart
|
||||
|
||||
1) Prepare the data and index
|
||||
|
||||
cd benchmarks/enron_emails
|
||||
python setup_enron_emails.py --data-dir data
|
||||
|
||||
Notes:
|
||||
- If `--emails-csv` is omitted, the script attempts to download from Kaggle dataset `wcukierski/enron-email-dataset` using Kaggle API (requires `KAGGLE_USERNAME` and `KAGGLE_KEY`).
|
||||
Alternatively, pass a local path to `--emails-csv`.
|
||||
|
||||
Notes:
|
||||
- The script parses emails, chunks header/body into passages, builds a compact LEANN index, and then builds a FAISS Flat baseline from the same passages and embedding model.
|
||||
- Optionally, it will also create evaluation queries from HuggingFace dataset `corbt/enron_emails_sample_questions`.
|
||||
|
||||
2) Run recall evaluation (Stage 2)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2
|
||||
|
||||
3) Complexity sweep (Stage 3)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 3 --target-recall 0.90 --max-queries 200
|
||||
|
||||
Stage 3 uses binary search over complexity to find the minimal value achieving the target Recall@3 (assumes recall is non-decreasing with complexity). The search expands the upper bound as needed and snaps complexity to multiples of 8.
|
||||
|
||||
4) Index comparison (Stage 4)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 4 --complexity 88 --max-queries 100 --output results.json
|
||||
|
||||
5) Generation evaluation (Stage 5)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 5 --complexity 88 --llm-backend hf --model-name Qwen/Qwen3-8B
|
||||
|
||||
6) Combined index + generation evaluation (Stages 4+5, recommended)
|
||||
|
||||
python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 45 --complexity 88 --llm-backend hf
|
||||
|
||||
Notes:
|
||||
- Minimal CLI: you can run from repo root with only `--index`, defaults match financebench/laion patterns:
|
||||
- `--stage` defaults to `all` (runs 2, 3, 4, 5)
|
||||
- `--baseline-dir` defaults to `baseline`
|
||||
- `--queries` defaults to `data/evaluation_queries.jsonl` (or falls back to the index directory)
|
||||
- `--llm-backend` defaults to `hf` (HuggingFace), can use `vllm`
|
||||
- `--model-name` defaults to `Qwen/Qwen3-8B`
|
||||
- Fail-fast behavior: no silent fallbacks. If compact index cannot run with recompute, it errors out.
|
||||
- Stage 5 requires Stage 4 retrieval results. Use `--stage 45` to run both efficiently.
|
||||
|
||||
Optional flags:
|
||||
- --queries data/evaluation_queries.jsonl (custom queries file)
|
||||
- --baseline-dir baseline (where FAISS baseline lives)
|
||||
- --complexity 88 (LEANN complexity parameter, optimal for 90% recall)
|
||||
- --llm-backend hf|vllm (LLM backend for generation)
|
||||
- --model-name Qwen/Qwen3-8B (LLM model for generation)
|
||||
- --max-queries 1000 (limit number of queries for evaluation)
|
||||
|
||||
## Files Produced
|
||||
- data/enron_passages_preview.jsonl: Small preview of passages used (for inspection)
|
||||
- data/enron_index_hnsw.leann.*: LEANN index files
|
||||
- baseline/faiss_flat.index + baseline/metadata.pkl: FAISS baseline with passage IDs
|
||||
- data/evaluation_queries.jsonl: Query file (id + query; includes GT IDs for reference)
|
||||
|
||||
## Notes
|
||||
- Evaluates both retrieval Recall@3 and generation timing with Qwen3-8B thinking model.
|
||||
- The emails CSV must contain a column named "message" (raw RFC822 email) and a column named "file" for source identifier. Message-ID headers are parsed as canonical message IDs when present.
|
||||
- Qwen3-8B requires special handling for thinking models with chat templates and <think></think> tag processing.
|
||||
|
||||
## Stages Summary
|
||||
|
||||
- Stage 2 (Recall@3):
|
||||
- Compares LEANN vs FAISS Flat baseline on Recall@3.
|
||||
- Compact index runs with `recompute_embeddings=True`.
|
||||
|
||||
- Stage 3 (Binary Search for Complexity):
|
||||
- Builds a non-compact index (`<index>_noncompact.leann`) and runs binary search with `recompute_embeddings=False` to find the minimal complexity achieving target Recall@3 (default 90%).
|
||||
|
||||
- Stage 4 (Index Comparison):
|
||||
- Reports .index-only sizes for compact vs non-compact.
|
||||
- Measures timings on queries by default: non-compact (no recompute) vs compact (with recompute).
|
||||
- Stores retrieval results for Stage 5 generation evaluation.
|
||||
- Fails fast if compact recompute cannot run.
|
||||
- If `--complexity` is not provided, the script tries to use the best complexity from Stage 3:
|
||||
- First from the current run (when running `--stage all`), otherwise
|
||||
- From `enron_stage3_results.json` saved next to the index during the last Stage 3 run.
|
||||
- If neither exists, Stage 4 will error and ask you to run Stage 3 or pass `--complexity`.
|
||||
|
||||
- Stage 5 (Generation Evaluation):
|
||||
- Uses Qwen3-8B thinking model for RAG generation on retrieved documents from Stage 4.
|
||||
- Supports HuggingFace (`hf`) and vLLM (`vllm`) backends.
|
||||
- Measures generation timing separately from search timing.
|
||||
- Requires Stage 4 results (no additional searching performed).
|
||||
|
||||
## Example Results
|
||||
|
||||
These are sample results obtained on Enron data using all-mpnet-base-v2 and Qwen3-8B.
|
||||
|
||||
- Stage 3 (Binary Search):
|
||||
- Minimal complexity achieving 90% Recall@3: 88
|
||||
- Sampled points:
|
||||
- C=8 → 59.9% Recall@3
|
||||
- C=72 → 89.4% Recall@3
|
||||
- C=88 → 90.2% Recall@3
|
||||
- C=96 → 90.7% Recall@3
|
||||
- C=112 → 91.1% Recall@3
|
||||
- C=136 → 91.3% Recall@3
|
||||
- C=256 → 92.0% Recall@3
|
||||
|
||||
- Stage 4 (Index Sizes, .index only):
|
||||
- Compact: ~2.2 MB
|
||||
- Non-compact: ~82.0 MB
|
||||
- Storage saving by compact: ~97.3%
|
||||
|
||||
- Stage 4 (Search Timing, 988 queries, complexity=88):
|
||||
- Non-compact (no recompute): ~0.0075 s avg per query
|
||||
- Compact (with recompute): ~1.981 s avg per query
|
||||
- Speed ratio (non-compact/compact): ~0.0038x
|
||||
|
||||
- Stage 5 (RAG Generation, 988 queries, Qwen3-8B):
|
||||
- Average generation time: ~22.302 s per query
|
||||
- Total queries processed: 988
|
||||
- LLM backend: HuggingFace transformers
|
||||
- Model: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
|
||||
Full JSON output is saved by the script (see `--output`), e.g.:
|
||||
`benchmarks/enron_emails/results_enron_stage45.json`.
|
||||
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
1
benchmarks/enron_emails/data/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
downloads/
|
||||
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
614
benchmarks/enron_emails/evaluate_enron_emails.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Enron Emails Benchmark Evaluation - Retrieval Recall@3 (Stages 2/3/4)
|
||||
Follows the style of FinanceBench/LAION: Stage 2 recall vs FAISS baseline,
|
||||
Stage 3 complexity sweep to target recall, Stage 4 index comparison.
|
||||
On errors, fail fast without fallbacks.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
# No fallbacks here; if embedding server is needed but fails, the caller will see the error.
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 using FAISS Flat as ground truth"""
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
for i, query in enumerate(queries):
|
||||
# Compute query embedding with the same model/mode as the index
|
||||
q_emb = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS Flat ground truth
|
||||
n = q_emb.shape[0]
|
||||
k = 3
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(q_emb),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN (may require embedding server depending on index configuration)
|
||||
results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {r.id for r in results}
|
||||
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0
|
||||
total_recall += recall
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: '{query[:60]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS: {list(baseline_ids)}")
|
||||
print(f" LEANN: {list(test_ids)}")
|
||||
print(f" ∩: {list(intersection)}")
|
||||
|
||||
avg = total_recall / max(1, len(queries))
|
||||
print(f"📊 Average Recall@3: {avg:.3f} ({avg * 100:.1f}%)")
|
||||
return avg
|
||||
|
||||
def cleanup(self):
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class EnronEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
queries: list[str] = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
if "query" in data:
|
||||
queries.append(data["query"])
|
||||
print(f"📊 Loaded {len(queries)} queries from {queries_file}")
|
||||
return queries
|
||||
|
||||
def cleanup(self):
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes (.index only), similar to LAION bench."""
|
||||
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json"
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl"
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx"
|
||||
|
||||
sizes["index_only_mb"] = (
|
||||
index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {sizes['index_only_mb']:.1f} MB")
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison using current passages and embeddings."""
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source and embedding model
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
# Load all passages and ids
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
texts.append(data["text"])
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
texts,
|
||||
meta["embedding_model"],
|
||||
mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False,
|
||||
is_compact=False,
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / f"{Path(non_compact_index_path).stem}_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = EnronEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_queries: list[str], complexity: int
|
||||
) -> dict:
|
||||
"""Compare search speed for non-compact vs compact indexes."""
|
||||
import time
|
||||
|
||||
results: dict = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
"retrieval_results": [], # Store retrieval results for Stage 5
|
||||
}
|
||||
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
# Non-compact (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
results["non_compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Compact (with recompute). Fail fast if it cannot run.
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
for q in test_queries:
|
||||
t0 = time.time()
|
||||
docs = compact_searcher.search(
|
||||
q, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
results["compact"]["search_times"].append(time.time() - t0)
|
||||
|
||||
# Store retrieval results for Stage 5
|
||||
results["retrieval_results"].append(
|
||||
{"query": q, "retrieved_docs": [{"id": doc.id, "text": doc.text} for doc in docs]}
|
||||
)
|
||||
compact_searcher.cleanup()
|
||||
|
||||
if results["non_compact"]["search_times"]:
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
if results["compact"]["search_times"]:
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
if results["avg_search_times"].get("compact", 0) > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = 0.0
|
||||
|
||||
non_compact_searcher.cleanup()
|
||||
return results
|
||||
|
||||
def evaluate_complexity(
|
||||
self,
|
||||
recall_eval: "RecallEvaluator",
|
||||
queries: list[str],
|
||||
target: float = 0.90,
|
||||
c_min: int = 8,
|
||||
c_max: int = 256,
|
||||
max_iters: int = 10,
|
||||
recompute: bool = False,
|
||||
) -> dict:
|
||||
"""Binary search minimal complexity achieving target recall (monotonic assumption)."""
|
||||
|
||||
def round_c(x: int) -> int:
|
||||
# snap to multiple of 8 like other benches typically do
|
||||
return max(1, int((x + 7) // 8) * 8)
|
||||
|
||||
metrics: list[dict] = []
|
||||
|
||||
lo = round_c(c_min)
|
||||
hi = round_c(c_max)
|
||||
|
||||
print(
|
||||
f"🧪 Binary search complexity in [{lo}, {hi}] for target Recall@3>={int(target * 100)}%..."
|
||||
)
|
||||
|
||||
# Ensure upper bound can reach target; expand if needed (up to a cap)
|
||||
r_lo = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=lo, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": lo, "recall_at_3": r_lo})
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
cap = 1024
|
||||
while r_hi < target and hi < cap:
|
||||
lo = hi
|
||||
r_lo = r_hi
|
||||
hi = round_c(hi * 2)
|
||||
r_hi = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=hi, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": hi, "recall_at_3": r_hi})
|
||||
|
||||
if r_hi < target:
|
||||
print(f"⚠️ Max complexity {hi} did not reach target recall {target:.2f}.")
|
||||
print("📈 Observations:")
|
||||
for m in metrics:
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
return {"metrics": metrics, "best_complexity": None, "target_recall": target}
|
||||
|
||||
# Binary search within [lo, hi]
|
||||
best = hi
|
||||
iters = 0
|
||||
while lo < hi and iters < max_iters:
|
||||
mid = round_c((lo + hi) // 2)
|
||||
r_mid = recall_eval.evaluate_recall_at_3(
|
||||
queries, complexity=mid, recompute_embeddings=recompute
|
||||
)
|
||||
metrics.append({"complexity": mid, "recall_at_3": r_mid})
|
||||
if r_mid >= target:
|
||||
best = mid
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid + 8 # move past mid, respecting multiple-of-8 step
|
||||
iters += 1
|
||||
|
||||
print("📈 Binary search results (sampled points):")
|
||||
# Print unique complexity entries ordered by complexity
|
||||
for m in sorted(
|
||||
{m["complexity"]: m for m in metrics}.values(), key=lambda x: x["complexity"]
|
||||
):
|
||||
print(f" C={m['complexity']:>4} -> Recall@3={m['recall_at_3'] * 100:.1f}%")
|
||||
print(f"✅ Minimal complexity achieving {int(target * 100)}% recall: {best}")
|
||||
return {"metrics": metrics, "best_complexity": best, "target_recall": target}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Enron Emails Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all", "45"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="LEANN search complexity")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument(
|
||||
"--max-queries", type=int, help="Limit number of queries to evaluate", default=1000
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target-recall", type=float, default=0.90, help="Target Recall@3 for Stage 3"
|
||||
)
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument("--llm-backend", choices=["hf", "vllm"], default="hf", help="LLM backend")
|
||||
parser.add_argument("--model-name", default="Qwen/Qwen3-8B", help="Model name")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve queries file: if default path not found, fall back to index's directory
|
||||
if not os.path.exists(args.queries):
|
||||
from pathlib import Path
|
||||
|
||||
idx_dir = Path(args.index).parent
|
||||
fallback_q = idx_dir / "evaluation_queries.jsonl"
|
||||
if fallback_q.exists():
|
||||
args.queries = str(fallback_q)
|
||||
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_enron_emails.py first to build the baseline")
|
||||
raise SystemExit(1)
|
||||
|
||||
results_out: dict = {}
|
||||
|
||||
if args.stage in ("2", "all"):
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[:10]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
complexity = args.complexity or 64
|
||||
r = evaluator.evaluate_recall_at_3(queries, complexity)
|
||||
results_out["stage2"] = {"complexity": complexity, "recall_at_3": r}
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
if args.stage in ("3", "all"):
|
||||
print("🚀 Starting Stage 3: Binary search for target recall (no recompute)")
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
queries = queries[: args.max_queries]
|
||||
print(f"🧪 Using first {len(queries)} queries")
|
||||
|
||||
# Build non-compact index for fast binary search (recompute_embeddings=False)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_index_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
enron_eval.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact evaluator for binary search with recompute=False
|
||||
evaluator_nc = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
sweep = enron_eval.evaluate_complexity(
|
||||
evaluator_nc, queries, target=args.target_recall, recompute=False
|
||||
)
|
||||
results_out["stage3"] = sweep
|
||||
# Persist default stage 3 results near the index for Stage 4 auto-pickup
|
||||
from pathlib import Path
|
||||
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
with open(default_stage3_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"stage3": sweep}, f, indent=2)
|
||||
print(f"📝 Saved Stage 3 summary to {default_stage3_path}")
|
||||
evaluator_nc.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 3 completed!\n")
|
||||
|
||||
if args.stage in ("4", "all", "45"):
|
||||
print("🚀 Starting Stage 4: Index size + performance comparison")
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
enron_eval = EnronEvaluator(args.index)
|
||||
queries = enron_eval.load_queries(args.queries)
|
||||
test_q = queries[: min(args.max_queries, len(queries))]
|
||||
|
||||
current_sizes = enron_eval.analyze_index_sizes()
|
||||
# Build non-compact index for comparison (no fallback)
|
||||
from pathlib import Path
|
||||
|
||||
index_path = Path(args.index)
|
||||
non_compact_path = str(index_path.parent / f"{index_path.stem}_noncompact.leann")
|
||||
non_compact_sizes = enron_eval.create_non_compact_index_for_comparison(non_compact_path)
|
||||
nc_eval = EnronEvaluator(non_compact_path)
|
||||
|
||||
if (
|
||||
current_sizes.get("index_only_mb", 0) > 0
|
||||
and non_compact_sizes.get("index_only_mb", 0) > 0
|
||||
):
|
||||
storage_saving_percent = max(
|
||||
0.0,
|
||||
100.0 * (1.0 - current_sizes["index_only_mb"] / non_compact_sizes["index_only_mb"]),
|
||||
)
|
||||
else:
|
||||
storage_saving_percent = 0.0
|
||||
|
||||
if args.complexity is None:
|
||||
# Prefer in-session Stage 3 result
|
||||
if "stage3" in results_out and results_out["stage3"].get("best_complexity") is not None:
|
||||
complexity = results_out["stage3"]["best_complexity"]
|
||||
print(f"📥 Using best complexity from Stage 3 in-session: {complexity}")
|
||||
else:
|
||||
# Try to load last saved Stage 3 result near index
|
||||
default_stage3_path = Path(args.index).parent / "enron_stage3_results.json"
|
||||
if default_stage3_path.exists():
|
||||
with open(default_stage3_path, encoding="utf-8") as f:
|
||||
prev = json.load(f)
|
||||
complexity = prev.get("stage3", {}).get("best_complexity")
|
||||
if complexity is None:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4: No --complexity and no best_complexity found in saved Stage 3 results"
|
||||
)
|
||||
print(f"📥 Using best complexity from saved Stage 3: {complexity}")
|
||||
else:
|
||||
raise SystemExit(
|
||||
"❌ Stage 4 requires --complexity if Stage 3 hasn't been run. Run stage 3 first or pass --complexity."
|
||||
)
|
||||
else:
|
||||
complexity = args.complexity
|
||||
|
||||
comp = enron_eval.compare_index_performance(
|
||||
non_compact_path, args.index, test_q, complexity=complexity
|
||||
)
|
||||
results_out["stage4"] = {
|
||||
"current_index": current_sizes,
|
||||
"non_compact_index": non_compact_sizes,
|
||||
"storage_saving_percent": storage_saving_percent,
|
||||
"performance_comparison": comp,
|
||||
}
|
||||
nc_eval.cleanup()
|
||||
evaluator.cleanup()
|
||||
enron_eval.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Generation evaluation with Qwen3-8B")
|
||||
|
||||
# Check if Stage 4 results exist
|
||||
if "stage4" not in results_out or "performance_comparison" not in results_out["stage4"]:
|
||||
print("❌ Stage 5 requires Stage 4 retrieval results")
|
||||
print("💡 Run Stage 4 first or use --stage all")
|
||||
raise SystemExit(1)
|
||||
|
||||
retrieval_results = results_out["stage4"]["performance_comparison"]["retrieval_results"]
|
||||
if not retrieval_results:
|
||||
print("❌ No retrieval results found from Stage 4")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"📁 Using {len(retrieval_results)} retrieval results from Stage 4")
|
||||
|
||||
# Load LLM
|
||||
try:
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Run generation using stored retrieval results
|
||||
import time
|
||||
|
||||
from llm_utils import create_prompt
|
||||
|
||||
generation_times = []
|
||||
responses = []
|
||||
|
||||
print("🤖 Running generation on pre-retrieved results...")
|
||||
for i, item in enumerate(retrieval_results):
|
||||
query = item["query"]
|
||||
retrieved_docs = item["retrieved_docs"]
|
||||
|
||||
# Prepare context from retrieved docs
|
||||
context = "\n\n".join([doc["text"] for doc in retrieved_docs])
|
||||
prompt = create_prompt(context, query, "emails")
|
||||
|
||||
# Time generation only
|
||||
gen_start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - gen_start
|
||||
|
||||
generation_times.append(gen_time)
|
||||
responses.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f" Q{i + 1}: Gen={gen_time:.3f}s")
|
||||
|
||||
avg_gen_time = sum(generation_times) / len(generation_times)
|
||||
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Queries: {len(retrieval_results)}")
|
||||
print(f" Avg Generation Time: {avg_gen_time:.3f}s")
|
||||
print(" (Search time from Stage 4)")
|
||||
|
||||
results_out["stage5"] = {
|
||||
"total_queries": len(retrieval_results),
|
||||
"avg_generation_time": avg_gen_time,
|
||||
"generation_times": generation_times,
|
||||
"responses": responses,
|
||||
}
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Results:")
|
||||
for i in range(min(3, len(retrieval_results))):
|
||||
query = retrieval_results[i]["query"]
|
||||
response = responses[i]
|
||||
print(f" Q{i + 1}: {query[:60]}...")
|
||||
print(f" A{i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers/vllm is installed and model is available")
|
||||
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.output and results_out:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(results_out, f, indent=2)
|
||||
print(f"📝 Saved results to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
359
benchmarks/enron_emails/setup_enron_emails.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Enron Emails Benchmark Setup Script
|
||||
Prepares passages from emails.csv, builds LEANN index, and FAISS Flat baseline
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from email import message_from_string
|
||||
from email.policy import default
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
|
||||
class EnronSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.passages_preview = self.data_dir / "enron_passages_preview.jsonl"
|
||||
self.index_path = self.data_dir / "enron_index_hnsw.leann"
|
||||
self.queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
self.downloads_dir = self.data_dir / "downloads"
|
||||
self.downloads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ----------------------------
|
||||
# Dataset acquisition
|
||||
# ----------------------------
|
||||
def ensure_emails_csv(self, emails_csv: Optional[str]) -> str:
|
||||
"""Return a path to emails.csv, downloading from Kaggle if needed."""
|
||||
if emails_csv:
|
||||
p = Path(emails_csv)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"emails.csv not found: {emails_csv}")
|
||||
return str(p)
|
||||
|
||||
print(
|
||||
"📥 Trying to download Enron emails.csv from Kaggle (wcukierski/enron-email-dataset)..."
|
||||
)
|
||||
try:
|
||||
from kaggle.api.kaggle_api_extended import KaggleApi
|
||||
|
||||
api = KaggleApi()
|
||||
api.authenticate()
|
||||
api.dataset_download_files(
|
||||
"wcukierski/enron-email-dataset", path=str(self.downloads_dir), unzip=True
|
||||
)
|
||||
candidate = self.downloads_dir / "emails.csv"
|
||||
if candidate.exists():
|
||||
print(f"✅ Downloaded emails.csv: {candidate}")
|
||||
return str(candidate)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"emails.csv was not found in {self.downloads_dir} after Kaggle download"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
"❌ Could not download via Kaggle automatically. Provide --emails-csv or configure Kaggle API."
|
||||
)
|
||||
print(
|
||||
" Set KAGGLE_USERNAME and KAGGLE_KEY env vars, or place emails.csv locally and pass --emails-csv."
|
||||
)
|
||||
raise e
|
||||
|
||||
# ----------------------------
|
||||
# Data preparation
|
||||
# ----------------------------
|
||||
@staticmethod
|
||||
def _extract_message_id(raw_email: str) -> str:
|
||||
msg = message_from_string(raw_email, policy=default)
|
||||
val = msg.get("Message-ID", "")
|
||||
if val.startswith("<") and val.endswith(">"):
|
||||
val = val[1:-1]
|
||||
return val or ""
|
||||
|
||||
@staticmethod
|
||||
def _split_header_body(raw_email: str) -> tuple[str, str]:
|
||||
parts = raw_email.split("\n\n", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[0].strip(), parts[1].strip()
|
||||
# Heuristic fallback
|
||||
first_lines = raw_email.splitlines()
|
||||
if first_lines and ":" in first_lines[0]:
|
||||
return raw_email.strip(), ""
|
||||
return "", raw_email.strip()
|
||||
|
||||
@staticmethod
|
||||
def _split_fixed_words(text: str, chunk_words: int, keep_last: bool) -> list[str]:
|
||||
text = (text or "").strip()
|
||||
if not text:
|
||||
return []
|
||||
if chunk_words <= 0:
|
||||
return [text]
|
||||
words = text.split()
|
||||
if not words:
|
||||
return []
|
||||
limit = len(words)
|
||||
if not keep_last:
|
||||
limit = (len(words) // chunk_words) * chunk_words
|
||||
if limit == 0:
|
||||
return []
|
||||
chunks = [" ".join(words[i : i + chunk_words]) for i in range(0, limit, chunk_words)]
|
||||
return [c for c in (s.strip() for s in chunks) if c]
|
||||
|
||||
def _iter_passages_from_csv(
|
||||
self,
|
||||
emails_csv: Path,
|
||||
chunk_words: int = 256,
|
||||
keep_last_header: bool = True,
|
||||
keep_last_body: bool = True,
|
||||
max_emails: int | None = None,
|
||||
) -> Iterable[dict]:
|
||||
with open(emails_csv, encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
count = 0
|
||||
for i, row in enumerate(reader):
|
||||
if max_emails is not None and count >= max_emails:
|
||||
break
|
||||
|
||||
raw_message = row.get("message", "")
|
||||
email_file_id = row.get("file", "")
|
||||
|
||||
if not raw_message.strip():
|
||||
continue
|
||||
|
||||
message_id = self._extract_message_id(raw_message)
|
||||
if not message_id:
|
||||
# Fallback ID based on CSV position and file path
|
||||
safe_file = re.sub(r"[^A-Za-z0-9_.-]", "_", email_file_id)
|
||||
message_id = f"enron_{i}_{safe_file}"
|
||||
|
||||
header, body = self._split_header_body(raw_message)
|
||||
|
||||
# Header chunks
|
||||
for chunk in self._split_fixed_words(header, chunk_words, keep_last_header):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": True,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
# Body chunks
|
||||
for chunk in self._split_fixed_words(body, chunk_words, keep_last_body):
|
||||
yield {
|
||||
"text": chunk,
|
||||
"metadata": {
|
||||
"message_id": message_id,
|
||||
"is_header": False,
|
||||
"email_file_id": email_file_id,
|
||||
},
|
||||
}
|
||||
|
||||
count += 1
|
||||
|
||||
# ----------------------------
|
||||
# Build LEANN index and FAISS baseline
|
||||
# ----------------------------
|
||||
def build_leann_index(
|
||||
self,
|
||||
emails_csv: Optional[str],
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
chunk_words: int = 256,
|
||||
max_emails: int | None = None,
|
||||
) -> str:
|
||||
emails_csv_path = self.ensure_emails_csv(emails_csv)
|
||||
print(f"🏗️ Building LEANN index from {emails_csv_path}...")
|
||||
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Stream passages and add to builder
|
||||
preview_written = 0
|
||||
with open(self.passages_preview, "w", encoding="utf-8") as preview_out:
|
||||
for p in self._iter_passages_from_csv(
|
||||
Path(emails_csv_path), chunk_words=chunk_words, max_emails=max_emails
|
||||
):
|
||||
builder.add_text(p["text"], metadata=p["metadata"])
|
||||
if preview_written < 200:
|
||||
preview_out.write(json.dumps({"text": p["text"][:200], **p["metadata"]}) + "\n")
|
||||
preview_written += 1
|
||||
|
||||
print(f"🔨 Building index at {self.index_path}...")
|
||||
builder.build_index(str(self.index_path))
|
||||
print("✅ LEANN index built!")
|
||||
return str(self.index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline") -> str:
|
||||
print("🔨 Building FAISS Flat baseline from LEANN passages...")
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read meta for passage source and embedding model
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
# Load passages from builder output so IDs match LEANN
|
||||
passages: list[str] = []
|
||||
passage_ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"]) # builder-assigned ID
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages for baseline")
|
||||
print(f"🤖 Embedding model: {embedding_model}")
|
||||
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
# Build FAISS IndexFlatIP
|
||||
dim = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dim)
|
||||
emb_f32 = embeddings.astype(np.float32)
|
||||
index.add(emb_f32.shape[0], faiss.swig_ptr(emb_f32))
|
||||
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as pf:
|
||||
pickle.dump(passage_ids, pf)
|
||||
|
||||
print(f"✅ FAISS baseline saved: {baseline_path}")
|
||||
print(f"✅ Metadata saved: {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
return baseline_path
|
||||
|
||||
# ----------------------------
|
||||
# Queries (optional): prepare evaluation queries file
|
||||
# ----------------------------
|
||||
def prepare_queries(self, min_realism: float = 0.85) -> Path:
|
||||
print(
|
||||
"📝 Preparing evaluation queries from HuggingFace dataset corbt/enron_emails_sample_questions ..."
|
||||
)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("corbt/enron_emails_sample_questions", split="train")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to load dataset: {e}")
|
||||
return self.queries_file
|
||||
|
||||
kept = 0
|
||||
with open(self.queries_file, "w", encoding="utf-8") as out:
|
||||
for i, item in enumerate(ds):
|
||||
how_realistic = float(item.get("how_realistic", 0.0))
|
||||
if how_realistic < min_realism:
|
||||
continue
|
||||
qid = str(item.get("id", f"enron_q_{i}"))
|
||||
query = item.get("question", "")
|
||||
if not query:
|
||||
continue
|
||||
record = {
|
||||
"id": qid,
|
||||
"query": query,
|
||||
# For reference only, not used in recall metric below
|
||||
"gt_message_ids": item.get("message_ids", []),
|
||||
}
|
||||
out.write(json.dumps(record) + "\n")
|
||||
kept += 1
|
||||
print(f"✅ Wrote {kept} queries to {self.queries_file}")
|
||||
return self.queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup Enron Emails Benchmark")
|
||||
parser.add_argument(
|
||||
"--emails-csv",
|
||||
help="Path to emails.csv (Enron dataset). If omitted, attempt Kaggle download.",
|
||||
)
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--backend", choices=["hnsw", "diskann"], default="hnsw")
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model for LEANN",
|
||||
)
|
||||
parser.add_argument("--chunk-words", type=int, default=256, help="Fixed word chunk size")
|
||||
parser.add_argument("--max-emails", type=int, help="Limit number of emails to process")
|
||||
parser.add_argument("--skip-queries", action="store_true", help="Skip creating queries file")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip building LEANN index")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
setup = EnronSetup(args.data_dir)
|
||||
|
||||
# Build index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
emails_csv=args.emails_csv,
|
||||
backend=args.backend,
|
||||
embedding_model=args.embedding_model,
|
||||
chunk_words=args.chunk_words,
|
||||
max_emails=args.max_emails,
|
||||
)
|
||||
|
||||
# Build FAISS baseline from the same passages & embeddings
|
||||
setup.build_faiss_flat_baseline(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping LEANN index build and baseline")
|
||||
|
||||
# Queries file (optional)
|
||||
if not args.skip_queries:
|
||||
setup.prepare_queries()
|
||||
else:
|
||||
print("⏭️ Skipping query preparation")
|
||||
|
||||
print("\n🎉 Enron Emails setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("Next steps:")
|
||||
print(
|
||||
"1) Evaluate recall: python evaluate_enron_emails.py --index data/enron_index_hnsw.leann --stage 2"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test only Faiss HNSW"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import psutil
|
||||
import gc
|
||||
import os
|
||||
|
||||
|
||||
def get_memory_usage():
|
||||
@@ -37,20 +37,20 @@ def main():
|
||||
import faiss
|
||||
except ImportError:
|
||||
print("Faiss is not installed.")
|
||||
print("Please install it with `uv pip install faiss-cpu` and you can then run this script again")
|
||||
print(
|
||||
"Please install it with `uv pip install faiss-cpu` and you can then run this script again"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
from llama_index.core import (
|
||||
SimpleDirectoryReader,
|
||||
VectorStoreIndex,
|
||||
StorageContext,
|
||||
Settings,
|
||||
node_parser,
|
||||
Document,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
tracker = MemoryTracker("Faiss HNSW")
|
||||
tracker.checkpoint("Initial")
|
||||
@@ -65,7 +65,7 @@ def main():
|
||||
tracker.checkpoint("After Faiss index creation")
|
||||
|
||||
documents = SimpleDirectoryReader(
|
||||
"examples/data",
|
||||
"data",
|
||||
recursive=True,
|
||||
encoding="utf-8",
|
||||
required_exts=[".pdf", ".txt", ".md"],
|
||||
@@ -90,8 +90,9 @@ def main():
|
||||
vector_store=vector_store, persist_dir="./storage_faiss"
|
||||
)
|
||||
from llama_index.core import load_index_from_storage
|
||||
|
||||
index = load_index_from_storage(storage_context=storage_context)
|
||||
print(f"Index loaded from ./storage_faiss")
|
||||
print("Index loaded from ./storage_faiss")
|
||||
tracker.checkpoint("After loading existing index")
|
||||
index_loaded = True
|
||||
except Exception as e:
|
||||
@@ -99,19 +100,18 @@ def main():
|
||||
print("Cleaning up corrupted index and building new one...")
|
||||
# Clean up corrupted index
|
||||
import shutil
|
||||
|
||||
if os.path.exists("./storage_faiss"):
|
||||
shutil.rmtree("./storage_faiss")
|
||||
|
||||
|
||||
if not index_loaded:
|
||||
print("Building new Faiss HNSW index...")
|
||||
|
||||
|
||||
# Use the correct Faiss building pattern from the example
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context,
|
||||
transformations=[node_parser]
|
||||
documents, storage_context=storage_context, transformations=[node_parser]
|
||||
)
|
||||
tracker.checkpoint("After index building")
|
||||
|
||||
@@ -124,10 +124,10 @@ def main():
|
||||
runtime_start_mem = get_memory_usage()
|
||||
print(f"Before load memory: {runtime_start_mem:.1f} MB")
|
||||
tracker.checkpoint("Before load memory")
|
||||
|
||||
|
||||
query_engine = index.as_query_engine(similarity_top_k=20)
|
||||
queries = [
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"什么是盘古大模型以及盘古开发过程中遇到了什么阴暗面,任务令一般在什么城市颁发",
|
||||
"What is LEANN and how does it work?",
|
||||
"华为诺亚方舟实验室的主要研究内容",
|
||||
]
|
||||
@@ -141,7 +141,7 @@ def main():
|
||||
|
||||
runtime_end_mem = get_memory_usage()
|
||||
runtime_overhead = runtime_end_mem - runtime_start_mem
|
||||
|
||||
|
||||
peak_memory = tracker.summary()
|
||||
print(f"Peak Memory: {peak_memory:.1f} MB")
|
||||
print(f"Runtime Memory Overhead: {runtime_overhead:.1f} MB")
|
||||
115
benchmarks/financebench/README.md
Normal file
115
benchmarks/financebench/README.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# FinanceBench Benchmark for LEANN-RAG
|
||||
|
||||
FinanceBench is a benchmark for evaluating retrieval-augmented generation (RAG) systems on financial document question-answering tasks.
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: [PatronusAI/financebench](https://huggingface.co/datasets/PatronusAI/financebench)
|
||||
- **Questions**: 150 financial Q&A examples
|
||||
- **Documents**: 368 PDF files (10-K, 10-Q, 8-K, earnings reports)
|
||||
- **Companies**: Major public companies (3M, Apple, Microsoft, Amazon, etc.)
|
||||
- **Paper**: [FinanceBench: A New Benchmark for Financial Question Answering](https://arxiv.org/abs/2311.11944)
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
benchmarks/financebench/
|
||||
├── setup_financebench.py # Downloads PDFs and builds index
|
||||
├── evaluate_financebench.py # Intelligent evaluation script
|
||||
├── data/
|
||||
│ ├── financebench_merged.jsonl # Q&A dataset
|
||||
│ ├── pdfs/ # Downloaded financial documents
|
||||
│ └── index/ # LEANN indexes
|
||||
│ └── financebench_full_hnsw.leann
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Setup (Download & Build Index)
|
||||
|
||||
```bash
|
||||
cd benchmarks/financebench
|
||||
python setup_financebench.py
|
||||
```
|
||||
|
||||
This will:
|
||||
- Download the 150 Q&A examples
|
||||
- Download all 368 PDF documents (parallel processing)
|
||||
- Build a LEANN index from 53K+ text chunks
|
||||
- Verify setup with test query
|
||||
|
||||
### 2. Evaluation
|
||||
|
||||
```bash
|
||||
# Basic retrieval evaluation
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann
|
||||
|
||||
|
||||
# RAG generation evaluation with Qwen3-8B
|
||||
python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann --stage 4 --complexity 64 --llm-backend hf --model-name Qwen/Qwen3-8B --output results_qwen3.json
|
||||
```
|
||||
|
||||
## Evaluation Methods
|
||||
|
||||
### Retrieval Evaluation
|
||||
Uses intelligent matching with three strategies:
|
||||
1. **Exact text overlap** - Direct substring matches
|
||||
2. **Number matching** - Key financial figures ($1,577, 1.2B, etc.)
|
||||
3. **Semantic similarity** - Word overlap with 20% threshold
|
||||
|
||||
### QA Evaluation
|
||||
LLM-based answer evaluation using GPT-4o:
|
||||
- Handles numerical rounding and equivalent representations
|
||||
- Considers fractions, percentages, and decimal equivalents
|
||||
- Evaluates semantic meaning rather than exact text match
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (sentence-transformers/all-mpnet-base-v2)
|
||||
|
||||
**Retrieval Metrics:**
|
||||
- **Question Coverage**: 100.0% (all questions retrieve relevant docs)
|
||||
- **Exact Match Rate**: 0.7% (substring overlap with evidence)
|
||||
- **Number Match Rate**: 120.7% (key financial figures matched)*
|
||||
- **Semantic Match Rate**: 4.7% (word overlap ≥20%)
|
||||
- **Average Search Time**: 0.097s
|
||||
|
||||
**QA Metrics:**
|
||||
- **Accuracy**: 42.7% (LLM-evaluated answer correctness)
|
||||
- **Average QA Time**: 4.71s (end-to-end response time)
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: 53,985 chunks from 368 PDFs
|
||||
- **Build Time**: ~5-10 minutes with sentence-transformers/all-mpnet-base-v2
|
||||
|
||||
*Note: Number match rate >100% indicates multiple retrieved documents contain the same financial figures, which is expected behavior for financial data appearing across multiple document sections.
|
||||
|
||||
### LEANN-RAG Generation Performance (Qwen3-8B)
|
||||
|
||||
- **Stage 4 (Index Comparison):**
|
||||
- Compact Index: 5.0 MB
|
||||
- Non-compact Index: 172.2 MB
|
||||
- **Storage Saving**: 97.1%
|
||||
- **Search Performance**:
|
||||
- Non-compact (no recompute): 0.009s avg per query
|
||||
- Compact (with recompute): 2.203s avg per query
|
||||
- Speed ratio: 0.004x
|
||||
|
||||
**Generation Evaluation (20 queries, complexity=64):**
|
||||
- **Average Search Time**: 1.638s per query
|
||||
- **Average Generation Time**: 45.957s per query
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen3-8B (thinking model with <think></think> processing)
|
||||
- **Total Questions Processed**: 20
|
||||
|
||||
## Options
|
||||
|
||||
```bash
|
||||
# Use different backends
|
||||
python setup_financebench.py --backend diskann
|
||||
python evaluate_financebench.py --index data/index/financebench_full_diskann.leann
|
||||
|
||||
# Use different embedding models
|
||||
python setup_financebench.py --embedding-model facebook/contriever
|
||||
```
|
||||
923
benchmarks/financebench/evaluate_financebench.py
Executable file
923
benchmarks/financebench/evaluate_financebench.py
Executable file
@@ -0,0 +1,923 @@
|
||||
"""
|
||||
FinanceBench Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
from leann import LeannChat, LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
from ..llm_utils import evaluate_rag, generate_hf, generate_vllm, load_hf_model, load_vllm_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (searcher vs baseline)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.passage_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} vectors")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, queries: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for given queries at specified complexity"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(queries)
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Get ground truth: search with FAISS flat
|
||||
from leann.api import compute_embeddings
|
||||
|
||||
query_embedding = compute_embeddings(
|
||||
[query],
|
||||
self.searcher.embedding_model,
|
||||
mode=self.searcher.embedding_mode,
|
||||
use_server=False,
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results
|
||||
baseline_ids = {self.passage_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity
|
||||
test_results = self.searcher.search(
|
||||
query,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{query[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class FinanceBenchEvaluator:
|
||||
def __init__(self, index_path: str, openai_api_key: Optional[str] = None):
|
||||
self.index_path = index_path
|
||||
self.openai_client = openai.OpenAI(api_key=openai_api_key) if openai_api_key else None
|
||||
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
self.chat = LeannChat(index_path) if openai_api_key else None
|
||||
|
||||
def load_dataset(self, dataset_path: str = "data/financebench_merged.jsonl"):
|
||||
"""Load FinanceBench dataset"""
|
||||
data = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data.append(json.loads(line))
|
||||
|
||||
print(f"📊 Loaded {len(data)} FinanceBench examples")
|
||||
return data
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes with and without embeddings"""
|
||||
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes = {}
|
||||
total_with_embeddings = 0
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
for file_path, name in [
|
||||
(index_file, "index"),
|
||||
(meta_file, "metadata"),
|
||||
(passages_file, "passages_text"),
|
||||
(passages_idx_file, "passages_index"),
|
||||
]:
|
||||
if file_path.exists():
|
||||
size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
sizes[name] = size_mb
|
||||
total_with_embeddings += size_mb
|
||||
|
||||
else:
|
||||
sizes[name] = 0
|
||||
|
||||
sizes["total_with_embeddings"] = total_with_embeddings
|
||||
sizes["index_only_mb"] = sizes["index"] # Just the .index file for fair comparison
|
||||
|
||||
print(f" 📁 Total index size: {total_with_embeddings:.1f} MB")
|
||||
print(f" 📁 Index file only: {sizes['index']:.1f} MB")
|
||||
|
||||
return sizes
|
||||
|
||||
def create_compact_index_for_comparison(self, compact_index_path: str) -> dict:
|
||||
"""Create a compact index for comparison purposes"""
|
||||
print("🏗️ Building compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage
|
||||
**meta.get("backend_kwargs", {}),
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building compact index at {compact_index_path}...")
|
||||
builder.build_index(compact_index_path)
|
||||
|
||||
# Analyze the compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(compact_index_path)
|
||||
compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
compact_sizes["index_type"] = "compact"
|
||||
|
||||
return compact_sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Build non-compact index with same passages
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=meta["embedding_model"],
|
||||
embedding_mode=meta.get("embedding_mode", "sentence-transformers"),
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact"]
|
||||
},
|
||||
)
|
||||
|
||||
# Load all passages
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
builder.add_text(data["text"], metadata=data.get("metadata", {}))
|
||||
|
||||
print(f"🔨 Building non-compact index at {non_compact_index_path}...")
|
||||
builder.build_index(non_compact_index_path)
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_data: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
import time
|
||||
|
||||
from leann import LeannSearcher
|
||||
|
||||
# Test queries
|
||||
test_queries = [item["question"] for item in test_data[:5]]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for query in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
query, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def evaluate_timing_breakdown(
|
||||
self, data: list[dict], max_samples: Optional[int] = None
|
||||
) -> dict:
|
||||
"""Evaluate timing breakdown and accuracy by hacking LeannChat.ask() for separated timing"""
|
||||
if not self.chat or not self.openai_client:
|
||||
print("⚠️ Skipping timing evaluation (no OpenAI API key provided)")
|
||||
return {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0.0,
|
||||
"avg_generation_time": 0.0,
|
||||
"avg_total_time": 0.0,
|
||||
"accuracy": 0.0,
|
||||
}
|
||||
|
||||
print("🔍🤖 Evaluating timing breakdown and accuracy (search + generation)...")
|
||||
|
||||
if max_samples:
|
||||
data = data[:max_samples]
|
||||
print(f"📝 Using first {max_samples} samples for timing evaluation")
|
||||
|
||||
search_times = []
|
||||
generation_times = []
|
||||
total_times = []
|
||||
correct_answers = 0
|
||||
|
||||
for i, item in enumerate(data):
|
||||
question = item["question"]
|
||||
ground_truth = item["answer"]
|
||||
|
||||
try:
|
||||
# Hack: Monkey-patch the ask method to capture internal timing
|
||||
original_ask = self.chat.ask
|
||||
captured_search_time = None
|
||||
captured_generation_time = None
|
||||
|
||||
def patched_ask(*args, **kwargs):
|
||||
nonlocal captured_search_time, captured_generation_time
|
||||
|
||||
# Time the search part
|
||||
search_start = time.time()
|
||||
results = self.chat.searcher.search(args[0], top_k=3, complexity=64)
|
||||
captured_search_time = time.time() - search_start
|
||||
|
||||
# Time the generation part
|
||||
context = "\n\n".join([r.text for r in results])
|
||||
prompt = (
|
||||
"Here is some retrieved context that might help answer your question:\n\n"
|
||||
f"{context}\n\n"
|
||||
f"Question: {args[0]}\n\n"
|
||||
"Please provide the best answer you can based on this context and your knowledge."
|
||||
)
|
||||
|
||||
generation_start = time.time()
|
||||
answer = self.chat.llm.ask(prompt)
|
||||
captured_generation_time = time.time() - generation_start
|
||||
|
||||
return answer
|
||||
|
||||
# Apply the patch
|
||||
self.chat.ask = patched_ask
|
||||
|
||||
# Time the total QA
|
||||
total_start = time.time()
|
||||
generated_answer = self.chat.ask(question)
|
||||
total_time = time.time() - total_start
|
||||
|
||||
# Restore original method
|
||||
self.chat.ask = original_ask
|
||||
|
||||
# Store the timings
|
||||
search_times.append(captured_search_time)
|
||||
generation_times.append(captured_generation_time)
|
||||
total_times.append(total_time)
|
||||
|
||||
# Check accuracy using LLM as judge
|
||||
is_correct = self._check_answer_accuracy(generated_answer, ground_truth, question)
|
||||
if is_correct:
|
||||
correct_answers += 1
|
||||
|
||||
status = "✅" if is_correct else "❌"
|
||||
print(
|
||||
f"Question {i + 1}/{len(data)}: {status} Search={captured_search_time:.3f}s, Gen={captured_generation_time:.3f}s, Total={total_time:.3f}s"
|
||||
)
|
||||
print(f" GT: {ground_truth}")
|
||||
print(f" Gen: {generated_answer[:100]}...")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error: {e}")
|
||||
search_times.append(0.0)
|
||||
generation_times.append(0.0)
|
||||
total_times.append(0.0)
|
||||
|
||||
accuracy = correct_answers / len(data) if data else 0.0
|
||||
|
||||
metrics = {
|
||||
"total_questions": len(data),
|
||||
"avg_search_time": sum(search_times) / len(search_times) if search_times else 0.0,
|
||||
"avg_generation_time": sum(generation_times) / len(generation_times)
|
||||
if generation_times
|
||||
else 0.0,
|
||||
"avg_total_time": sum(total_times) / len(total_times) if total_times else 0.0,
|
||||
"accuracy": accuracy,
|
||||
"correct_answers": correct_answers,
|
||||
"search_times": search_times,
|
||||
"generation_times": generation_times,
|
||||
"total_times": total_times,
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
def _check_answer_accuracy(
|
||||
self, generated_answer: str, ground_truth: str, question: str
|
||||
) -> bool:
|
||||
"""Check if generated answer matches ground truth using LLM as judge"""
|
||||
judge_prompt = f"""You are an expert judge evaluating financial question answering.
|
||||
|
||||
Question: {question}
|
||||
|
||||
Ground Truth Answer: {ground_truth}
|
||||
|
||||
Generated Answer: {generated_answer}
|
||||
|
||||
Task: Determine if the generated answer is factually correct compared to the ground truth. Focus on:
|
||||
1. Numerical accuracy (exact values, units, currency)
|
||||
2. Key financial concepts and terminology
|
||||
3. Overall factual correctness
|
||||
|
||||
For financial data, small formatting differences are OK (e.g., "$1,577" vs "1577 million" vs "$1.577 billion"), but the core numerical value must match.
|
||||
|
||||
Respond with exactly one word: "CORRECT" if the generated answer is factually accurate, or "INCORRECT" if it's wrong or significantly different."""
|
||||
|
||||
try:
|
||||
judge_response = self.openai_client.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[{"role": "user", "content": judge_prompt}],
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
)
|
||||
judgment = judge_response.choices[0].message.content.strip().upper()
|
||||
return judgment == "CORRECT"
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Judge error: {e}, falling back to string matching")
|
||||
# Fallback to simple string matching
|
||||
gen_clean = generated_answer.strip().lower().replace("$", "").replace(",", "")
|
||||
gt_clean = ground_truth.strip().lower().replace("$", "").replace(",", "")
|
||||
return gt_clean in gen_clean
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 EVALUATION RESULTS")
|
||||
print("=" * 50)
|
||||
|
||||
# Index comparison analysis
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
print(f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB")
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(f" Total index size: {timing_metrics.get('total_with_embeddings', 0):.1f} MB")
|
||||
|
||||
print("\n📊 Accuracy:")
|
||||
print(f" Accuracy: {timing_metrics.get('accuracy', 0) * 100:.1f}%")
|
||||
print(
|
||||
f" Correct Answers: {timing_metrics.get('correct_answers', 0)}/{timing_metrics.get('total_questions', 0)}"
|
||||
)
|
||||
|
||||
print("\n📊 Timing Breakdown:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
print(f" Avg Total Time: {timing_metrics.get('avg_total_time', 0):.3f}s")
|
||||
|
||||
if timing_metrics.get("avg_total_time", 0) > 0:
|
||||
search_pct = (
|
||||
timing_metrics.get("avg_search_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
gen_pct = (
|
||||
timing_metrics.get("avg_generation_time", 0)
|
||||
/ timing_metrics.get("avg_total_time", 1)
|
||||
* 100
|
||||
)
|
||||
print("\n📈 Time Distribution:")
|
||||
print(f" Search: {search_pct:.1f}%")
|
||||
print(f" Generation: {gen_pct:.1f}%")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Modular FinanceBench Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument("--dataset", default="data/financebench_merged.jsonl", help="Dataset path")
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--openai-api-key", help="OpenAI API key for generation evaluation")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend", choices=["openai", "hf", "vllm"], default="openai", help="LLM backend"
|
||||
)
|
||||
parser.add_argument("--model-name", default="Qwen3-8B", help="Model name for HF/vLLM")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_financebench.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load FinanceBench queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Test with more queries for robust measurement
|
||||
test_queries = queries[:2000]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_queries, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity (using non-compact index for speed)
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search (will be reused in Stage 4)
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = FinanceBenchEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load queries for testing
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
queries = []
|
||||
with open(args.dataset, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
|
||||
# Use more queries for robust measurement
|
||||
test_queries = queries[:200]
|
||||
print(f"🧪 Testing with {len(test_queries)} queries")
|
||||
|
||||
# Binary search for 90% recall complexity (without recompute for speed)
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 32
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_queries, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_queries[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Comprehensive evaluation with dual index comparison
|
||||
print("🚀 Starting Stage 4: Comprehensive evaluation with dual index comparison")
|
||||
|
||||
# Use FinanceBench evaluator for QA evaluation
|
||||
evaluator = FinanceBenchEvaluator(
|
||||
args.index, args.openai_api_key if args.llm_backend == "openai" else None
|
||||
)
|
||||
|
||||
print("📖 Loading FinanceBench dataset...")
|
||||
data = evaluator.load_dataset(args.dataset)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = FinanceBenchEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes
|
||||
print("\n📊 Index size comparison:")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index: {non_compact_size_metrics['total_with_embeddings']:.1f} MB"
|
||||
)
|
||||
print("\n📊 Index-only size comparison (.index file only):")
|
||||
print(f" Compact index: {compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
print(f" Non-compact index: {non_compact_size_metrics['index_only_mb']:.1f} MB")
|
||||
# Use index-only size for fair comparison (same as Enron emails)
|
||||
storage_saving = (
|
||||
(non_compact_size_metrics["index_only_mb"] - compact_size_metrics["index_only_mb"])
|
||||
/ non_compact_size_metrics["index_only_mb"]
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for performance comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, data[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Step 5: Generation evaluation
|
||||
test_samples = 20
|
||||
print(f"\n🧪 Testing with first {test_samples} samples for generation analysis")
|
||||
|
||||
if args.llm_backend == "openai" and args.openai_api_key:
|
||||
print("🔍🤖 Running OpenAI-based generation evaluation...")
|
||||
evaluation_start = time.time()
|
||||
timing_metrics = evaluator.evaluate_timing_breakdown(data[:test_samples])
|
||||
evaluation_time = time.time() - evaluation_start
|
||||
else:
|
||||
print(
|
||||
f"🔍🤖 Running {args.llm_backend} generation evaluation with {args.model_name}..."
|
||||
)
|
||||
try:
|
||||
# Load LLM
|
||||
if args.llm_backend == "hf":
|
||||
tokenizer, model = load_hf_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_hf(tokenizer, model, prompt)
|
||||
else: # vllm
|
||||
llm, sampling_params = load_vllm_model(args.model_name)
|
||||
|
||||
def llm_func(prompt):
|
||||
return generate_vllm(llm, sampling_params, prompt)
|
||||
|
||||
# Simple generation evaluation
|
||||
queries = [item["question"] for item in data[:test_samples]]
|
||||
gen_results = evaluate_rag(
|
||||
evaluator.searcher,
|
||||
llm_func,
|
||||
queries,
|
||||
domain="finance",
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
timing_metrics = {
|
||||
"total_questions": len(queries),
|
||||
"avg_search_time": gen_results["avg_search_time"],
|
||||
"avg_generation_time": gen_results["avg_generation_time"],
|
||||
"results": gen_results["results"],
|
||||
}
|
||||
evaluation_time = time.time()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation evaluation failed: {e}")
|
||||
timing_metrics = {
|
||||
"total_questions": 0,
|
||||
"avg_search_time": 0,
|
||||
"avg_generation_time": 0,
|
||||
}
|
||||
evaluation_time = 0
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
**timing_metrics,
|
||||
"total_evaluation_time": evaluation_time,
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print results
|
||||
print("\n📊 Generation Results:")
|
||||
print(f" Total Questions: {timing_metrics.get('total_questions', 0)}")
|
||||
print(f" Avg Search Time: {timing_metrics.get('avg_search_time', 0):.3f}s")
|
||||
print(f" Avg Generation Time: {timing_metrics.get('avg_generation_time', 0):.3f}s")
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Generation accuracy & timing evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review accuracy and timing breakdown for performance optimization")
|
||||
print(" - Run full evaluation on complete dataset if needed")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
from pathlib import Path
|
||||
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
462
benchmarks/financebench/setup_financebench.py
Executable file
462
benchmarks/financebench/setup_financebench.py
Executable file
@@ -0,0 +1,462 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
FinanceBench Complete Setup Script
|
||||
Downloads all PDFs and builds full LEANN datastore
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import pymupdf
|
||||
import requests
|
||||
from leann import LeannBuilder, LeannSearcher
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class FinanceBenchSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.base_dir = Path(__file__).parent # benchmarks/financebench/
|
||||
self.data_dir = self.base_dir / data_dir
|
||||
self.pdf_dir = self.data_dir / "pdfs"
|
||||
self.dataset_file = self.data_dir / "financebench_merged.jsonl"
|
||||
self.index_dir = self.data_dir / "index"
|
||||
self.download_lock = Lock()
|
||||
|
||||
def download_dataset(self):
|
||||
"""Download the main FinanceBench dataset"""
|
||||
print("📊 Downloading FinanceBench dataset...")
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self.dataset_file.exists():
|
||||
print(f"✅ Dataset already exists: {self.dataset_file}")
|
||||
return
|
||||
|
||||
url = "https://huggingface.co/datasets/PatronusAI/financebench/raw/main/financebench_merged.jsonl"
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(self.dataset_file, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
print(f"✅ Dataset downloaded: {self.dataset_file}")
|
||||
|
||||
def get_pdf_list(self):
|
||||
"""Get list of all PDF files from GitHub"""
|
||||
print("📋 Fetching PDF list from GitHub...")
|
||||
|
||||
response = requests.get(
|
||||
"https://api.github.com/repos/patronus-ai/financebench/contents/pdfs"
|
||||
)
|
||||
response.raise_for_status()
|
||||
pdf_files = response.json()
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files")
|
||||
return pdf_files
|
||||
|
||||
def download_single_pdf(self, pdf_info, position):
|
||||
"""Download a single PDF file"""
|
||||
pdf_name = pdf_info["name"]
|
||||
pdf_path = self.pdf_dir / pdf_name
|
||||
|
||||
# Skip if already downloaded
|
||||
if pdf_path.exists() and pdf_path.stat().st_size > 0:
|
||||
return f"✅ {pdf_name} (cached)"
|
||||
|
||||
try:
|
||||
# Download PDF
|
||||
response = requests.get(pdf_info["download_url"], timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Write to file
|
||||
with self.download_lock:
|
||||
with open(pdf_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return f"✅ {pdf_name} ({len(response.content) // 1024}KB)"
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ {pdf_name}: {e!s}"
|
||||
|
||||
def download_all_pdfs(self, max_workers: int = 5):
|
||||
"""Download all PDF files with parallel processing"""
|
||||
self.pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_files = self.get_pdf_list()
|
||||
|
||||
print(f"📥 Downloading {len(pdf_files)} PDFs with {max_workers} workers...")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit all download tasks
|
||||
future_to_pdf = {
|
||||
executor.submit(self.download_single_pdf, pdf_info, i): pdf_info["name"]
|
||||
for i, pdf_info in enumerate(pdf_files)
|
||||
}
|
||||
|
||||
# Process completed downloads with progress bar
|
||||
with tqdm(total=len(pdf_files), desc="Downloading PDFs") as pbar:
|
||||
for future in as_completed(future_to_pdf):
|
||||
result = future.result()
|
||||
pbar.set_postfix_str(result.split()[-1] if "✅" in result else "Error")
|
||||
pbar.update(1)
|
||||
|
||||
# Verify downloads
|
||||
downloaded_pdfs = list(self.pdf_dir.glob("*.pdf"))
|
||||
print(f"✅ Successfully downloaded {len(downloaded_pdfs)}/{len(pdf_files)} PDFs")
|
||||
|
||||
# Show any failures
|
||||
missing_pdfs = []
|
||||
for pdf_info in pdf_files:
|
||||
pdf_path = self.pdf_dir / pdf_info["name"]
|
||||
if not pdf_path.exists() or pdf_path.stat().st_size == 0:
|
||||
missing_pdfs.append(pdf_info["name"])
|
||||
|
||||
if missing_pdfs:
|
||||
print(f"⚠️ Failed to download {len(missing_pdfs)} PDFs:")
|
||||
for pdf in missing_pdfs[:5]: # Show first 5
|
||||
print(f" - {pdf}")
|
||||
if len(missing_pdfs) > 5:
|
||||
print(f" ... and {len(missing_pdfs) - 5} more")
|
||||
|
||||
def build_leann_index(
|
||||
self,
|
||||
backend: str = "hnsw",
|
||||
embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
):
|
||||
"""Build LEANN index from all PDFs"""
|
||||
print(f"🏗️ Building LEANN index with {backend} backend...")
|
||||
|
||||
# Check if we have PDFs
|
||||
pdf_files = list(self.pdf_dir.glob("*.pdf"))
|
||||
if not pdf_files:
|
||||
raise RuntimeError("No PDF files found! Run download first.")
|
||||
|
||||
print(f"Found {len(pdf_files)} PDF files to process")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Initialize builder with standard compact configuration
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model=embedding_model,
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=True, # Enable recompute (no stored embeddings)
|
||||
is_compact=True, # Enable compact storage (pruned)
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Process PDFs and extract text
|
||||
total_chunks = 0
|
||||
failed_pdfs = []
|
||||
|
||||
for pdf_path in tqdm(pdf_files, desc="Processing PDFs"):
|
||||
try:
|
||||
chunks = self.extract_pdf_text(pdf_path)
|
||||
for chunk in chunks:
|
||||
builder.add_text(chunk["text"], metadata=chunk["metadata"])
|
||||
total_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to process {pdf_path.name}: {e}")
|
||||
failed_pdfs.append(pdf_path.name)
|
||||
continue
|
||||
|
||||
# Build index in index directory
|
||||
self.index_dir.mkdir(parents=True, exist_ok=True)
|
||||
index_path = self.index_dir / f"financebench_full_{backend}.leann"
|
||||
print(f"🔨 Building index: {index_path}")
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
|
||||
print("✅ Index built successfully!")
|
||||
print(f" 📁 Index path: {index_path}")
|
||||
print(f" 📊 Total chunks: {total_chunks:,}")
|
||||
print(f" 📄 Processed PDFs: {len(pdf_files) - len(failed_pdfs)}/{len(pdf_files)}")
|
||||
print(f" ⏱️ Build time: {build_time:.1f}s")
|
||||
|
||||
if failed_pdfs:
|
||||
print(f" ⚠️ Failed PDFs: {failed_pdfs}")
|
||||
|
||||
return str(index_path)
|
||||
|
||||
def build_faiss_flat_baseline(self, index_path: str, output_dir: str = "baseline"):
|
||||
"""Build FAISS flat baseline using the same embeddings as LEANN index"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from leann.api import compute_embeddings
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Read metadata from the built index
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
import json
|
||||
|
||||
meta = json.loads(f.read())
|
||||
|
||||
embedding_model = meta["embedding_model"]
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not os.path.isabs(passage_file):
|
||||
index_dir = os.path.dirname(index_path)
|
||||
passage_file = os.path.join(index_dir, os.path.basename(passage_file))
|
||||
|
||||
print(f"📊 Loading passages from {passage_file}...")
|
||||
print(f"🤖 Using embedding model: {embedding_model}")
|
||||
|
||||
# Load all passages for baseline
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"📄 Loaded {len(passages)} passages")
|
||||
|
||||
# Compute embeddings using the same method as LEANN
|
||||
print("🧮 Computing embeddings...")
|
||||
embeddings = compute_embeddings(
|
||||
passages,
|
||||
embedding_model,
|
||||
mode="sentence-transformers",
|
||||
use_server=False,
|
||||
)
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(passage_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def extract_pdf_text(self, pdf_path: Path) -> list[dict]:
|
||||
"""Extract and chunk text from a PDF file"""
|
||||
chunks = []
|
||||
doc = pymupdf.open(pdf_path)
|
||||
|
||||
for page_num in range(len(doc)):
|
||||
page = doc[page_num]
|
||||
text = page.get_text() # type: ignore
|
||||
|
||||
if not text.strip():
|
||||
continue
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source_file": pdf_path.name,
|
||||
"page_number": page_num + 1,
|
||||
"document_type": "10K" if "10K" in pdf_path.name else "10Q",
|
||||
"company": pdf_path.name.split("_")[0],
|
||||
"doc_period": self.extract_year_from_filename(pdf_path.name),
|
||||
}
|
||||
|
||||
# Use recursive character splitting like LangChain
|
||||
if len(text.split()) > 500:
|
||||
# Split by double newlines (paragraphs)
|
||||
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
|
||||
|
||||
current_chunk = ""
|
||||
for para in paragraphs:
|
||||
# If adding this paragraph would make chunk too long, save current chunk
|
||||
if current_chunk and len((current_chunk + " " + para).split()) > 300:
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
current_chunk = para
|
||||
else:
|
||||
current_chunk = (current_chunk + " " + para).strip()
|
||||
|
||||
# Add the last chunk
|
||||
if current_chunk.strip():
|
||||
chunks.append(
|
||||
{
|
||||
"text": current_chunk.strip(),
|
||||
"metadata": {
|
||||
**metadata,
|
||||
"chunk_id": f"page_{page_num + 1}_chunk_{len(chunks)}",
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Page is short enough, use as single chunk
|
||||
chunks.append(
|
||||
{
|
||||
"text": text.strip(),
|
||||
"metadata": {**metadata, "chunk_id": f"page_{page_num + 1}"},
|
||||
}
|
||||
)
|
||||
|
||||
doc.close()
|
||||
return chunks
|
||||
|
||||
def extract_year_from_filename(self, filename: str) -> str:
|
||||
"""Extract year from PDF filename"""
|
||||
# Try to find 4-digit year in filename
|
||||
|
||||
match = re.search(r"(\d{4})", filename)
|
||||
return match.group(1) if match else "unknown"
|
||||
|
||||
def verify_setup(self, index_path: str):
|
||||
"""Verify the setup by testing a simple query"""
|
||||
print("🧪 Verifying setup with test query...")
|
||||
|
||||
try:
|
||||
searcher = LeannSearcher(index_path)
|
||||
|
||||
# Test query
|
||||
test_query = "What is the capital expenditure for 3M in 2018?"
|
||||
results = searcher.search(test_query, top_k=3)
|
||||
|
||||
print(f"✅ Test query successful! Found {len(results)} results:")
|
||||
for i, result in enumerate(results, 1):
|
||||
company = result.metadata.get("company", "Unknown")
|
||||
year = result.metadata.get("doc_period", "Unknown")
|
||||
page = result.metadata.get("page_number", "Unknown")
|
||||
print(f" {i}. {company} {year} (page {page}) - Score: {result.score:.3f}")
|
||||
print(f" {result.text[:100]}...")
|
||||
|
||||
searcher.cleanup()
|
||||
print("✅ Setup verification completed successfully!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Setup verification failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup FinanceBench with full PDF datastore")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument(
|
||||
"--backend", choices=["hnsw", "diskann"], default="hnsw", help="LEANN backend"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-model",
|
||||
default="sentence-transformers/all-mpnet-base-v2",
|
||||
help="Embedding model",
|
||||
)
|
||||
parser.add_argument("--max-workers", type=int, default=5, help="Parallel download workers")
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip PDF download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
parser.add_argument(
|
||||
"--build-baseline-only",
|
||||
action="store_true",
|
||||
help="Only build FAISS baseline from existing index",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🏦 FinanceBench Complete Setup")
|
||||
print("=" * 50)
|
||||
|
||||
setup = FinanceBenchSetup(args.data_dir)
|
||||
|
||||
try:
|
||||
if args.build_baseline_only:
|
||||
# Only build baseline from existing index
|
||||
index_path = setup.index_dir / f"financebench_full_{args.backend}"
|
||||
index_file = f"{index_path}.index"
|
||||
meta_file = f"{index_path}.leann.meta.json"
|
||||
|
||||
if not os.path.exists(index_file) or not os.path.exists(meta_file):
|
||||
print("❌ Index files not found:")
|
||||
print(f" Index: {index_file}")
|
||||
print(f" Meta: {meta_file}")
|
||||
print("💡 Run without --build-baseline-only to build the index first")
|
||||
exit(1)
|
||||
|
||||
print(f"🔨 Building baseline from existing index: {index_path}")
|
||||
baseline_path = setup.build_faiss_flat_baseline(str(index_path))
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
return
|
||||
|
||||
# Step 1: Download dataset
|
||||
setup.download_dataset()
|
||||
|
||||
# Step 2: Download PDFs
|
||||
if not args.skip_download:
|
||||
setup.download_all_pdfs(max_workers=args.max_workers)
|
||||
else:
|
||||
print("⏭️ Skipping PDF download")
|
||||
|
||||
# Step 3: Build LEANN index
|
||||
if not args.skip_build:
|
||||
index_path = setup.build_leann_index(
|
||||
backend=args.backend, embedding_model=args.embedding_model
|
||||
)
|
||||
|
||||
# Step 4: Build FAISS flat baseline
|
||||
print("\n🔨 Building FAISS flat baseline...")
|
||||
baseline_path = setup.build_faiss_flat_baseline(index_path)
|
||||
print(f"✅ Baseline built at {baseline_path}")
|
||||
|
||||
# Step 5: Verify setup
|
||||
setup.verify_setup(index_path)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
|
||||
print("\n🎉 FinanceBench setup completed!")
|
||||
print(f"📁 Data directory: {setup.data_dir.absolute()}")
|
||||
print("\nNext steps:")
|
||||
print(
|
||||
"1. Run evaluation: python evaluate_financebench.py --index data/index/financebench_full_hnsw.leann"
|
||||
)
|
||||
print(
|
||||
"2. Or test manually: python -c \"from leann import LeannSearcher; s = LeannSearcher('data/index/financebench_full_hnsw.leann'); print(s.search('3M capital expenditure 2018'))\""
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
214
benchmarks/financebench/verify_recall.py
Normal file
214
benchmarks/financebench/verify_recall.py
Normal file
@@ -0,0 +1,214 @@
|
||||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.9"
|
||||
# dependencies = [
|
||||
# "faiss-cpu",
|
||||
# "numpy",
|
||||
# "sentence-transformers",
|
||||
# "torch",
|
||||
# "tqdm",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
"""
|
||||
Independent recall verification script using standard FAISS.
|
||||
Creates two indexes (HNSW and Flat) and compares recall@3 at different complexities.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def compute_embeddings_direct(chunks: list[str], model_name: str) -> np.ndarray:
|
||||
"""
|
||||
Direct embedding computation using sentence-transformers.
|
||||
Copied logic to avoid dependency issues.
|
||||
"""
|
||||
print(f"Loading model: {model_name}")
|
||||
model = SentenceTransformer(model_name)
|
||||
|
||||
print(f"Computing embeddings for {len(chunks)} chunks...")
|
||||
embeddings = model.encode(
|
||||
chunks,
|
||||
show_progress_bar=True,
|
||||
batch_size=32,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=False,
|
||||
)
|
||||
|
||||
return embeddings.astype(np.float32)
|
||||
|
||||
|
||||
def load_financebench_queries(dataset_path: str, max_queries: int = 200) -> list[str]:
|
||||
"""Load FinanceBench queries from dataset"""
|
||||
queries = []
|
||||
with open(dataset_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
queries.append(data["question"])
|
||||
if len(queries) >= max_queries:
|
||||
break
|
||||
return queries
|
||||
|
||||
|
||||
def load_passages_from_leann_index(index_path: str) -> tuple[list[str], list[str]]:
|
||||
"""Load passages from LEANN index structure"""
|
||||
meta_path = f"{index_path}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
index_dir = Path(index_path).parent
|
||||
passage_file = index_dir / Path(passage_file).name
|
||||
|
||||
print(f"Loading passages from {passage_file}")
|
||||
|
||||
passages = []
|
||||
passage_ids = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Loading passages"):
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
passages.append(data["text"])
|
||||
passage_ids.append(data["id"])
|
||||
|
||||
print(f"Loaded {len(passages)} passages")
|
||||
return passages, passage_ids
|
||||
|
||||
|
||||
def build_faiss_indexes(embeddings: np.ndarray) -> tuple[faiss.Index, faiss.Index]:
|
||||
"""Build FAISS indexes: Flat (ground truth) and HNSW"""
|
||||
dimension = embeddings.shape[1]
|
||||
|
||||
# Build Flat index (ground truth)
|
||||
print("Building FAISS IndexFlatIP (ground truth)...")
|
||||
flat_index = faiss.IndexFlatIP(dimension)
|
||||
flat_index.add(embeddings)
|
||||
|
||||
# Build HNSW index
|
||||
print("Building FAISS IndexHNSWFlat...")
|
||||
M = 32 # Same as LEANN default
|
||||
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
|
||||
hnsw_index.hnsw.efConstruction = 200 # Same as LEANN default
|
||||
hnsw_index.add(embeddings)
|
||||
|
||||
print(f"Built indexes with {flat_index.ntotal} vectors, dimension {dimension}")
|
||||
return flat_index, hnsw_index
|
||||
|
||||
|
||||
def evaluate_recall_at_k(
|
||||
query_embeddings: np.ndarray,
|
||||
flat_index: faiss.Index,
|
||||
hnsw_index: faiss.Index,
|
||||
passage_ids: list[str],
|
||||
k: int = 3,
|
||||
ef_search: int = 64,
|
||||
) -> float:
|
||||
"""Evaluate recall@k comparing HNSW vs Flat"""
|
||||
|
||||
# Set search parameters for HNSW
|
||||
hnsw_index.hnsw.efSearch = ef_search
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = query_embeddings.shape[0]
|
||||
|
||||
for i in range(num_queries):
|
||||
query = query_embeddings[i : i + 1] # Keep 2D shape
|
||||
|
||||
# Get ground truth from Flat index (standard FAISS API)
|
||||
flat_distances, flat_indices = flat_index.search(query, k)
|
||||
ground_truth_ids = {passage_ids[idx] for idx in flat_indices[0]}
|
||||
|
||||
# Get results from HNSW index (standard FAISS API)
|
||||
hnsw_distances, hnsw_indices = hnsw_index.search(query, k)
|
||||
hnsw_ids = {passage_ids[idx] for idx in hnsw_indices[0]}
|
||||
|
||||
# Calculate recall
|
||||
intersection = ground_truth_ids.intersection(hnsw_ids)
|
||||
recall = len(intersection) / k
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: Recall@{k} = {recall:.3f}")
|
||||
print(f" Flat: {list(ground_truth_ids)}")
|
||||
print(f" HNSW: {list(hnsw_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
return avg_recall
|
||||
|
||||
|
||||
def main():
|
||||
# Configuration
|
||||
dataset_path = "data/financebench_merged.jsonl"
|
||||
index_path = "data/index/financebench_full_hnsw.leann"
|
||||
embedding_model = "sentence-transformers/all-mpnet-base-v2"
|
||||
|
||||
print("🔍 FAISS Recall Verification")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if files exist
|
||||
if not Path(dataset_path).exists():
|
||||
print(f"❌ Dataset not found: {dataset_path}")
|
||||
return
|
||||
if not Path(f"{index_path}.meta.json").exists():
|
||||
print(f"❌ Index metadata not found: {index_path}.meta.json")
|
||||
return
|
||||
|
||||
# Load data
|
||||
print("📖 Loading FinanceBench queries...")
|
||||
queries = load_financebench_queries(dataset_path, max_queries=50)
|
||||
print(f"Loaded {len(queries)} queries")
|
||||
|
||||
print("📄 Loading passages from LEANN index...")
|
||||
passages, passage_ids = load_passages_from_leann_index(index_path)
|
||||
|
||||
# Compute embeddings
|
||||
print("🧮 Computing passage embeddings...")
|
||||
passage_embeddings = compute_embeddings_direct(passages, embedding_model)
|
||||
|
||||
print("🧮 Computing query embeddings...")
|
||||
query_embeddings = compute_embeddings_direct(queries, embedding_model)
|
||||
|
||||
# Build FAISS indexes
|
||||
print("🏗️ Building FAISS indexes...")
|
||||
flat_index, hnsw_index = build_faiss_indexes(passage_embeddings)
|
||||
|
||||
# Test different efSearch values (equivalent to LEANN complexity)
|
||||
print("\n📊 Evaluating Recall@3 at different efSearch values...")
|
||||
ef_search_values = [16, 32, 64, 128, 256]
|
||||
|
||||
for ef_search in ef_search_values:
|
||||
print(f"\n🧪 Testing efSearch = {ef_search}")
|
||||
start_time = time.time()
|
||||
|
||||
recall = evaluate_recall_at_k(
|
||||
query_embeddings, flat_index, hnsw_index, passage_ids, k=3, ef_search=ef_search
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
f"📈 efSearch {ef_search}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%) in {elapsed:.2f}s"
|
||||
)
|
||||
|
||||
print("\n✅ Verification completed!")
|
||||
print("\n📋 Summary:")
|
||||
print(" - Built independent FAISS Flat and HNSW indexes")
|
||||
print(" - Compared recall@3 at different efSearch values")
|
||||
print(" - Used same embedding model as LEANN")
|
||||
print(" - This validates LEANN's recall measurements")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
benchmarks/laion/.gitignore
vendored
Normal file
1
benchmarks/laion/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
data/
|
||||
199
benchmarks/laion/README.md
Normal file
199
benchmarks/laion/README.md
Normal file
@@ -0,0 +1,199 @@
|
||||
# LAION Multimodal Benchmark
|
||||
|
||||
A multimodal benchmark for evaluating image retrieval and generation performance using LEANN with CLIP embeddings and Qwen2.5-VL for multimodal generation on LAION dataset subset.
|
||||
|
||||
## Overview
|
||||
|
||||
This benchmark evaluates:
|
||||
- **Image retrieval timing** using caption-based queries
|
||||
- **Recall@K performance** for image search
|
||||
- **Complexity analysis** across different search parameters
|
||||
- **Index size and storage efficiency**
|
||||
- **Multimodal generation** with Qwen2.5-VL for image understanding and description
|
||||
|
||||
## Dataset Configuration
|
||||
|
||||
- **Dataset**: LAION-400M subset (10,000 images)
|
||||
- **Embeddings**: Pre-computed CLIP ViT-B/32 (512 dimensions)
|
||||
- **Queries**: 200 random captions from the dataset
|
||||
- **Ground Truth**: Self-recall (query caption → original image)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Setup the benchmark
|
||||
|
||||
```bash
|
||||
cd benchmarks/laion
|
||||
python setup_laion.py --num-samples 10000 --num-queries 200
|
||||
```
|
||||
|
||||
This will:
|
||||
- Create dummy LAION data (10K samples)
|
||||
- Generate CLIP embeddings (512-dim)
|
||||
- Build LEANN index with HNSW backend
|
||||
- Create 200 evaluation queries
|
||||
|
||||
### 2. Run evaluation
|
||||
|
||||
```bash
|
||||
# Run all evaluation stages
|
||||
python evaluate_laion.py --index data/laion_index.leann
|
||||
|
||||
# Run specific stages
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 2 # Recall evaluation
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 3 # Complexity analysis
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 4 # Index comparison
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 # Multimodal generation
|
||||
|
||||
# Multimodal generation with Qwen2.5-VL
|
||||
python evaluate_laion.py --index data/laion_index.leann --stage 5 --model-name Qwen/Qwen2.5-VL-7B-Instruct
|
||||
```
|
||||
|
||||
### 3. Save results
|
||||
|
||||
```bash
|
||||
python evaluate_laion.py --index data/laion_index.leann --output results.json
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Setup Options
|
||||
```bash
|
||||
python setup_laion.py \
|
||||
--num-samples 10000 \
|
||||
--num-queries 200 \
|
||||
--index-path data/laion_index.leann \
|
||||
--backend hnsw
|
||||
```
|
||||
|
||||
### Evaluation Options
|
||||
```bash
|
||||
python evaluate_laion.py \
|
||||
--index data/laion_index.leann \
|
||||
--queries data/evaluation_queries.jsonl \
|
||||
--complexity 64 \
|
||||
--top-k 3 \
|
||||
--num-samples 100 \
|
||||
--stage all
|
||||
```
|
||||
|
||||
## Evaluation Stages
|
||||
|
||||
### Stage 2: Recall Evaluation
|
||||
- Evaluates Recall@3 for multimodal retrieval
|
||||
- Compares LEANN vs FAISS baseline performance
|
||||
- Self-recall: query caption should retrieve original image
|
||||
|
||||
### Stage 3: Complexity Analysis
|
||||
- Binary search for optimal complexity (90% recall target)
|
||||
- Tests performance across different complexity levels
|
||||
- Analyzes speed vs. accuracy tradeoffs
|
||||
|
||||
### Stage 4: Index Comparison
|
||||
- Compares compact vs non-compact index sizes
|
||||
- Measures search performance differences
|
||||
- Reports storage efficiency and speed ratios
|
||||
|
||||
### Stage 5: Multimodal Generation
|
||||
- Uses Qwen2.5-VL for image understanding and description
|
||||
- Retrieval-Augmented Generation (RAG) with multimodal context
|
||||
- Measures both search and generation timing
|
||||
|
||||
## Output Metrics
|
||||
|
||||
### Timing Metrics
|
||||
- Average/median/min/max search time
|
||||
- Standard deviation
|
||||
- Searches per second
|
||||
- Latency in milliseconds
|
||||
|
||||
### Recall Metrics
|
||||
- Recall@3 percentage for image retrieval
|
||||
- Number of queries with ground truth
|
||||
|
||||
### Index Metrics
|
||||
- Total index size (MB)
|
||||
- Component breakdown (index, passages, metadata)
|
||||
- Storage savings (compact vs non-compact)
|
||||
- Backend and embedding model info
|
||||
|
||||
### Generation Metrics (Stage 5)
|
||||
- Average search time per query
|
||||
- Average generation time per query
|
||||
- Time distribution (search vs generation)
|
||||
- Sample multimodal responses
|
||||
- Model: Qwen2.5-VL performance
|
||||
|
||||
## Benchmark Results
|
||||
|
||||
### LEANN-RAG Performance (CLIP ViT-L/14 + Qwen2.5-VL)
|
||||
|
||||
**Stage 3: Optimal Complexity Analysis**
|
||||
- **Optimal Complexity**: 85 (achieving 90% Recall@3)
|
||||
- **Binary Search Range**: 1-128
|
||||
- **Target Recall**: 90%
|
||||
- **Index Type**: Non-compact (for fast binary search)
|
||||
|
||||
**Stage 5: Multimodal Generation Performance (Qwen2.5-VL)**
|
||||
- **Total Queries**: 20
|
||||
- **Average Search Time**: 1.200s per query
|
||||
- **Average Generation Time**: 6.558s per query
|
||||
- **Time Distribution**: Search 15.5%, Generation 84.5%
|
||||
- **LLM Backend**: HuggingFace transformers
|
||||
- **Model**: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
- **Optimal Complexity**: 85
|
||||
|
||||
**System Performance:**
|
||||
- **Index Size**: ~10,000 image embeddings from LAION subset
|
||||
- **Embedding Model**: CLIP ViT-L/14 (768 dimensions)
|
||||
- **Backend**: HNSW with cosine distance
|
||||
|
||||
### Example Results
|
||||
|
||||
```
|
||||
🎯 LAION MULTIMODAL BENCHMARK RESULTS
|
||||
============================================================
|
||||
|
||||
📊 Multimodal Generation Results:
|
||||
Total Queries: 20
|
||||
Avg Search Time: 1.200s
|
||||
Avg Generation Time: 6.558s
|
||||
Time Distribution: Search 15.5%, Generation 84.5%
|
||||
LLM Backend: HuggingFace transformers
|
||||
Model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||
|
||||
⚙️ Optimal Complexity Analysis:
|
||||
Target Recall: 90%
|
||||
Optimal Complexity: 85
|
||||
Binary Search Range: 1-128
|
||||
Non-compact Index (fast search, no recompute)
|
||||
|
||||
🚀 Performance Summary:
|
||||
Multimodal RAG: 7.758s total per query
|
||||
Search: 15.5% of total time
|
||||
Generation: 84.5% of total time
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
benchmarks/laion/
|
||||
├── setup_laion.py # Setup script
|
||||
├── evaluate_laion.py # Evaluation script
|
||||
├── README.md # This file
|
||||
└── data/ # Generated data
|
||||
├── laion_images/ # Image files (placeholder)
|
||||
├── laion_metadata.jsonl # Image metadata
|
||||
├── laion_passages.jsonl # LEANN passages
|
||||
├── laion_embeddings.npy # CLIP embeddings
|
||||
├── evaluation_queries.jsonl # Evaluation queries
|
||||
└── laion_index.leann/ # LEANN index files
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Current implementation uses dummy data for demonstration
|
||||
- For real LAION data, implement actual download logic in `setup_laion.py`
|
||||
- CLIP embeddings are randomly generated - replace with real CLIP model for production
|
||||
- Adjust `num_samples` and `num_queries` based on available resources
|
||||
- Consider using `--num-samples` during evaluation for faster testing
|
||||
725
benchmarks/laion/evaluate_laion.py
Normal file
725
benchmarks/laion/evaluate_laion.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Evaluation Script - Modular Recall-based Evaluation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from leann import LeannSearcher
|
||||
from leann_backend_hnsw import faiss
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from ..llm_utils import evaluate_multimodal_rag, load_qwen_vl_model
|
||||
|
||||
# Setup logging to reduce verbose output
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger("leann.api").setLevel(logging.WARNING)
|
||||
logging.getLogger("leann_backend_hnsw").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RecallEvaluator:
|
||||
"""Stage 2: Evaluate Recall@3 (LEANN vs FAISS baseline for multimodal retrieval)"""
|
||||
|
||||
def __init__(self, index_path: str, baseline_dir: str):
|
||||
self.index_path = index_path
|
||||
self.baseline_dir = baseline_dir
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
# Load FAISS flat baseline (image embeddings)
|
||||
baseline_index_path = os.path.join(baseline_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(baseline_dir, "metadata.pkl")
|
||||
|
||||
self.faiss_index = faiss.read_index(baseline_index_path)
|
||||
with open(metadata_path, "rb") as f:
|
||||
self.image_ids = pickle.load(f)
|
||||
print(f"📚 Loaded FAISS flat baseline with {self.faiss_index.ntotal} image vectors")
|
||||
|
||||
# Load sentence-transformers CLIP for text embedding (ViT-L/14)
|
||||
self.st_clip = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
def evaluate_recall_at_3(
|
||||
self, captions: list[str], complexity: int = 64, recompute_embeddings: bool = True
|
||||
) -> float:
|
||||
"""Evaluate recall@3 for multimodal retrieval: caption queries -> image results"""
|
||||
recompute_str = "with recompute" if recompute_embeddings else "no recompute"
|
||||
print(f"🔍 Evaluating recall@3 with complexity={complexity} ({recompute_str})...")
|
||||
|
||||
total_recall = 0.0
|
||||
num_queries = len(captions)
|
||||
|
||||
for i, caption in enumerate(captions):
|
||||
# Get ground truth: search with FAISS flat using caption text embedding
|
||||
# Generate CLIP text embedding for caption via sentence-transformers (normalized)
|
||||
query_embedding = self.st_clip.encode(
|
||||
[caption], convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False
|
||||
).astype(np.float32)
|
||||
|
||||
# Search FAISS flat for ground truth using LEANN's modified faiss API
|
||||
n = query_embedding.shape[0] # Number of queries
|
||||
k = 3 # Number of nearest neighbors
|
||||
distances = np.zeros((n, k), dtype=np.float32)
|
||||
labels = np.zeros((n, k), dtype=np.int64)
|
||||
|
||||
self.faiss_index.search(
|
||||
n,
|
||||
faiss.swig_ptr(query_embedding),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(labels),
|
||||
)
|
||||
|
||||
# Extract the results (image IDs from FAISS)
|
||||
baseline_ids = {self.image_ids[idx] for idx in labels[0]}
|
||||
|
||||
# Search with LEANN at specified complexity (using caption as text query)
|
||||
test_results = self.searcher.search(
|
||||
caption,
|
||||
top_k=3,
|
||||
complexity=complexity,
|
||||
recompute_embeddings=recompute_embeddings,
|
||||
)
|
||||
test_ids = {result.id for result in test_results}
|
||||
|
||||
# Calculate recall@3 = |intersection| / |ground_truth|
|
||||
intersection = test_ids.intersection(baseline_ids)
|
||||
recall = len(intersection) / 3.0 # Ground truth size is 3
|
||||
total_recall += recall
|
||||
|
||||
if i < 3: # Show first few examples
|
||||
print(f" Query {i + 1}: '{caption[:50]}...' -> Recall@3: {recall:.3f}")
|
||||
print(f" FAISS ground truth: {list(baseline_ids)}")
|
||||
print(f" LEANN results (C={complexity}, {recompute_str}): {list(test_ids)}")
|
||||
print(f" Intersection: {list(intersection)}")
|
||||
|
||||
avg_recall = total_recall / num_queries
|
||||
print(f"📊 Average Recall@3: {avg_recall:.3f} ({avg_recall * 100:.1f}%)")
|
||||
return avg_recall
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if hasattr(self, "searcher"):
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
class LAIONEvaluator:
|
||||
def __init__(self, index_path: str):
|
||||
self.index_path = index_path
|
||||
self.searcher = LeannSearcher(index_path)
|
||||
|
||||
def load_queries(self, queries_file: str) -> list[str]:
|
||||
"""Load caption queries from evaluation file"""
|
||||
captions = []
|
||||
with open(queries_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
query_data = json.loads(line)
|
||||
captions.append(query_data["query"])
|
||||
|
||||
print(f"📊 Loaded {len(captions)} caption queries")
|
||||
return captions
|
||||
|
||||
def analyze_index_sizes(self) -> dict:
|
||||
"""Analyze index sizes, emphasizing .index only (exclude passages)."""
|
||||
print("📏 Analyzing index sizes (.index only)...")
|
||||
|
||||
# Get all index-related files
|
||||
index_path = Path(self.index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.stem # Remove .leann extension
|
||||
|
||||
sizes: dict[str, float] = {}
|
||||
|
||||
# Core index files
|
||||
index_file = index_dir / f"{index_name}.index"
|
||||
meta_file = index_dir / f"{index_path.name}.meta.json" # Keep .leann for meta file
|
||||
passages_file = index_dir / f"{index_path.name}.passages.jsonl" # Keep .leann for passages
|
||||
passages_idx_file = index_dir / f"{index_path.name}.passages.idx" # Keep .leann for idx
|
||||
|
||||
# Core index size (.index only)
|
||||
index_mb = index_file.stat().st_size / (1024 * 1024) if index_file.exists() else 0.0
|
||||
sizes["index_only_mb"] = index_mb
|
||||
|
||||
# Other files for reference (not counted in index_only_mb)
|
||||
sizes["metadata_mb"] = (
|
||||
meta_file.stat().st_size / (1024 * 1024) if meta_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_text_mb"] = (
|
||||
passages_file.stat().st_size / (1024 * 1024) if passages_file.exists() else 0.0
|
||||
)
|
||||
sizes["passages_index_mb"] = (
|
||||
passages_idx_file.stat().st_size / (1024 * 1024) if passages_idx_file.exists() else 0.0
|
||||
)
|
||||
|
||||
print(f" 📁 .index size: {index_mb:.1f} MB")
|
||||
if sizes["metadata_mb"]:
|
||||
print(f" 🧾 metadata: {sizes['metadata_mb']:.3f} MB")
|
||||
if sizes["passages_text_mb"] or sizes["passages_index_mb"]:
|
||||
print(
|
||||
f" (passages excluded) text: {sizes['passages_text_mb']:.1f} MB, idx: {sizes['passages_index_mb']:.1f} MB"
|
||||
)
|
||||
|
||||
return sizes
|
||||
|
||||
def create_non_compact_index_for_comparison(self, non_compact_index_path: str) -> dict:
|
||||
"""Create a non-compact index for comparison purposes"""
|
||||
print("🏗️ Building non-compact index from existing passages...")
|
||||
|
||||
# Load existing passages from current index
|
||||
from leann import LeannBuilder
|
||||
|
||||
current_index_path = Path(self.index_path)
|
||||
current_index_dir = current_index_path.parent
|
||||
current_index_name = current_index_path.name
|
||||
|
||||
# Read metadata to get passage source
|
||||
meta_path = current_index_dir / f"{current_index_name}.meta.json"
|
||||
with open(meta_path) as f:
|
||||
meta = json.load(f)
|
||||
|
||||
passage_source = meta["passage_sources"][0]
|
||||
passage_file = passage_source["path"]
|
||||
|
||||
# Convert relative path to absolute
|
||||
if not Path(passage_file).is_absolute():
|
||||
passage_file = current_index_dir / Path(passage_file).name
|
||||
|
||||
print(f"📄 Loading passages from {passage_file}...")
|
||||
|
||||
# Load CLIP embeddings
|
||||
embeddings_file = current_index_dir / "clip_image_embeddings.npy"
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📐 Loaded embeddings shape: {embeddings.shape}")
|
||||
|
||||
# Build non-compact index with same passages and embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
# Use CLIP text encoder (ViT-L/14) to match image embeddings (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
is_recompute=False, # Disable recompute (store embeddings)
|
||||
is_compact=False, # Disable compact storage
|
||||
distance_metric="cosine",
|
||||
**{
|
||||
k: v
|
||||
for k, v in meta.get("backend_kwargs", {}).items()
|
||||
if k not in ["is_recompute", "is_compact", "distance_metric"]
|
||||
},
|
||||
)
|
||||
|
||||
# Prepare ids and add passages
|
||||
ids: list[str] = []
|
||||
with open(passage_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
ids.append(str(data["id"]))
|
||||
# Ensure metadata contains the id used by the vector index
|
||||
metadata = {**data.get("metadata", {}), "id": data["id"]}
|
||||
builder.add_text(text=data["text"], metadata=metadata)
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
# Persist a pickle for build_index_from_embeddings
|
||||
pkl_path = current_index_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
|
||||
print(
|
||||
f"🔨 Building non-compact index at {non_compact_index_path} from precomputed embeddings..."
|
||||
)
|
||||
builder.build_index_from_embeddings(non_compact_index_path, str(pkl_path))
|
||||
|
||||
# Analyze the non-compact index size
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_sizes = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_sizes["index_type"] = "non_compact"
|
||||
|
||||
return non_compact_sizes
|
||||
|
||||
def compare_index_performance(
|
||||
self, non_compact_path: str, compact_path: str, test_captions: list, complexity: int
|
||||
) -> dict:
|
||||
"""Compare performance between non-compact and compact indexes"""
|
||||
print("⚡ Comparing search performance between indexes...")
|
||||
|
||||
# Test queries
|
||||
test_queries = test_captions[:5]
|
||||
|
||||
results = {
|
||||
"non_compact": {"search_times": []},
|
||||
"compact": {"search_times": []},
|
||||
"avg_search_times": {},
|
||||
"speed_ratio": 0.0,
|
||||
}
|
||||
|
||||
# Test non-compact index (no recompute)
|
||||
print(" 🔍 Testing non-compact index (no recompute)...")
|
||||
non_compact_searcher = LeannSearcher(non_compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = non_compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=False
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["non_compact"]["search_times"].append(search_time)
|
||||
|
||||
# Test compact index (with recompute)
|
||||
print(" 🔍 Testing compact index (with recompute)...")
|
||||
compact_searcher = LeannSearcher(compact_path)
|
||||
|
||||
for caption in test_queries:
|
||||
start_time = time.time()
|
||||
_ = compact_searcher.search(
|
||||
caption, top_k=3, complexity=complexity, recompute_embeddings=True
|
||||
)
|
||||
search_time = time.time() - start_time
|
||||
results["compact"]["search_times"].append(search_time)
|
||||
|
||||
# Calculate averages
|
||||
results["avg_search_times"]["non_compact"] = sum(
|
||||
results["non_compact"]["search_times"]
|
||||
) / len(results["non_compact"]["search_times"])
|
||||
results["avg_search_times"]["compact"] = sum(results["compact"]["search_times"]) / len(
|
||||
results["compact"]["search_times"]
|
||||
)
|
||||
|
||||
# Performance ratio
|
||||
if results["avg_search_times"]["compact"] > 0:
|
||||
results["speed_ratio"] = (
|
||||
results["avg_search_times"]["non_compact"] / results["avg_search_times"]["compact"]
|
||||
)
|
||||
else:
|
||||
results["speed_ratio"] = float("inf")
|
||||
|
||||
print(
|
||||
f" Non-compact (no recompute): {results['avg_search_times']['non_compact']:.3f}s avg"
|
||||
)
|
||||
print(f" Compact (with recompute): {results['avg_search_times']['compact']:.3f}s avg")
|
||||
print(f" Speed ratio: {results['speed_ratio']:.2f}x")
|
||||
|
||||
# Cleanup
|
||||
non_compact_searcher.cleanup()
|
||||
compact_searcher.cleanup()
|
||||
|
||||
return results
|
||||
|
||||
def _print_results(self, timing_metrics: dict):
|
||||
"""Print evaluation results"""
|
||||
print("\n🎯 LAION MULTIMODAL BENCHMARK RESULTS")
|
||||
print("=" * 60)
|
||||
|
||||
# Index comparison analysis (prefer .index-only view if present)
|
||||
if "current_index" in timing_metrics and "non_compact_index" in timing_metrics:
|
||||
current = timing_metrics["current_index"]
|
||||
non_compact = timing_metrics["non_compact_index"]
|
||||
|
||||
if "index_only_mb" in current and "index_only_mb" in non_compact:
|
||||
print("\n📏 Index Comparison Analysis (.index only):")
|
||||
print(f" Compact index (current): {current.get('index_only_mb', 0):.1f} MB")
|
||||
print(f" Non-compact index: {non_compact.get('index_only_mb', 0):.1f} MB")
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
# Show excluded components for reference if available
|
||||
if any(
|
||||
k in non_compact
|
||||
for k in ("passages_text_mb", "passages_index_mb", "metadata_mb")
|
||||
):
|
||||
print(" (passages excluded in totals, shown for reference):")
|
||||
print(
|
||||
f" - Passages text: {non_compact.get('passages_text_mb', 0):.1f} MB, "
|
||||
f"Passages index: {non_compact.get('passages_index_mb', 0):.1f} MB, "
|
||||
f"Metadata: {non_compact.get('metadata_mb', 0):.3f} MB"
|
||||
)
|
||||
else:
|
||||
# Fallback to legacy totals if running with older metrics
|
||||
print("\n📏 Index Comparison Analysis:")
|
||||
print(
|
||||
f" Compact index (current): {current.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Non-compact index (with embeddings): {non_compact.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Storage saving by compact: {timing_metrics.get('storage_saving_percent', 0):.1f}%"
|
||||
)
|
||||
print(" Component breakdown (non-compact):")
|
||||
print(f" - Main index: {non_compact.get('index', 0):.1f} MB")
|
||||
print(f" - Passages text: {non_compact.get('passages_text', 0):.1f} MB")
|
||||
print(f" - Passages index: {non_compact.get('passages_index', 0):.1f} MB")
|
||||
print(f" - Metadata: {non_compact.get('metadata', 0):.1f} MB")
|
||||
|
||||
# Performance comparison
|
||||
if "performance_comparison" in timing_metrics:
|
||||
perf = timing_metrics["performance_comparison"]
|
||||
print("\n⚡ Performance Comparison:")
|
||||
print(
|
||||
f" Non-compact (no recompute): {perf.get('avg_search_times', {}).get('non_compact', 0):.3f}s avg"
|
||||
)
|
||||
print(
|
||||
f" Compact (with recompute): {perf.get('avg_search_times', {}).get('compact', 0):.3f}s avg"
|
||||
)
|
||||
print(f" Speed ratio: {perf.get('speed_ratio', 0):.2f}x")
|
||||
|
||||
# Legacy single index analysis (fallback)
|
||||
if "total_with_embeddings" in timing_metrics and "current_index" not in timing_metrics:
|
||||
print("\n📏 Index Size Analysis:")
|
||||
print(
|
||||
f" Index with embeddings: {timing_metrics.get('total_with_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(
|
||||
f" Estimated pruned index: {timing_metrics.get('total_without_embeddings', 0):.1f} MB"
|
||||
)
|
||||
print(f" Compression ratio: {timing_metrics.get('compression_ratio', 0):.2f}x")
|
||||
|
||||
def cleanup(self):
|
||||
"""Cleanup resources"""
|
||||
if self.searcher:
|
||||
self.searcher.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="LAION Multimodal Benchmark Evaluation")
|
||||
parser.add_argument("--index", required=True, help="Path to LEANN index")
|
||||
parser.add_argument(
|
||||
"--queries", default="data/evaluation_queries.jsonl", help="Path to evaluation queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stage",
|
||||
choices=["2", "3", "4", "5", "all"],
|
||||
default="all",
|
||||
help="Which stage to run (2=recall, 3=complexity, 4=index comparison, 5=generation)",
|
||||
)
|
||||
parser.add_argument("--complexity", type=int, default=None, help="Complexity for search")
|
||||
parser.add_argument("--baseline-dir", default="baseline", help="Baseline output directory")
|
||||
parser.add_argument("--output", help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--llm-backend",
|
||||
choices=["hf"],
|
||||
default="hf",
|
||||
help="LLM backend (Qwen2.5-VL only supports HF)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name", default="Qwen/Qwen2.5-VL-7B-Instruct", help="Multimodal model name"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Check if baseline exists
|
||||
baseline_index_path = os.path.join(args.baseline_dir, "faiss_flat.index")
|
||||
if not os.path.exists(baseline_index_path):
|
||||
print(f"❌ FAISS baseline not found at {baseline_index_path}")
|
||||
print("💡 Please run setup_laion.py first to build the baseline")
|
||||
exit(1)
|
||||
|
||||
if args.stage == "2" or args.stage == "all":
|
||||
# Stage 2: Recall@3 evaluation
|
||||
print("🚀 Starting Stage 2: Recall@3 evaluation for multimodal retrieval")
|
||||
|
||||
evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
laion_evaluator = LAIONEvaluator(args.index)
|
||||
captions = laion_evaluator.load_queries(args.queries)
|
||||
|
||||
# Test with queries for robust measurement
|
||||
test_captions = captions[:100] # Use subset for speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Test with complexity 64
|
||||
complexity = 64
|
||||
recall = evaluator.evaluate_recall_at_3(test_captions, complexity)
|
||||
print(f"📈 Recall@3 at complexity {complexity}: {recall * 100:.1f}%")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 2 completed!\n")
|
||||
|
||||
# Shared non-compact index path for Stage 3 and 4
|
||||
non_compact_index_path = args.index.replace(".leann", "_noncompact.leann")
|
||||
complexity = args.complexity
|
||||
|
||||
if args.stage == "3" or args.stage == "all":
|
||||
# Stage 3: Binary search for 90% recall complexity
|
||||
print("🚀 Starting Stage 3: Binary search for 90% recall complexity")
|
||||
print(
|
||||
"💡 Creating non-compact index for fast binary search with recompute_embeddings=False"
|
||||
)
|
||||
|
||||
# Create non-compact index for binary search
|
||||
print("🏗️ Creating non-compact index for binary search...")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
evaluator.create_non_compact_index_for_comparison(non_compact_index_path)
|
||||
|
||||
# Use non-compact index for binary search
|
||||
binary_search_evaluator = RecallEvaluator(non_compact_index_path, args.baseline_dir)
|
||||
|
||||
# Load caption queries for testing
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Use subset for robust measurement
|
||||
test_captions = captions[:50] # Smaller subset for binary search speed
|
||||
print(f"🧪 Testing with {len(test_captions)} caption queries")
|
||||
|
||||
# Binary search for 90% recall complexity
|
||||
target_recall = 0.9
|
||||
min_complexity, max_complexity = 1, 128
|
||||
|
||||
print(f"🔍 Binary search for {target_recall * 100}% recall complexity...")
|
||||
print(f"Search range: {min_complexity} to {max_complexity}")
|
||||
|
||||
best_complexity = None
|
||||
best_recall = 0.0
|
||||
|
||||
while min_complexity <= max_complexity:
|
||||
mid_complexity = (min_complexity + max_complexity) // 2
|
||||
|
||||
print(
|
||||
f"\n🧪 Testing complexity {mid_complexity} (no recompute, non-compact index)..."
|
||||
)
|
||||
# Use recompute_embeddings=False on non-compact index for fast binary search
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, mid_complexity, recompute_embeddings=False
|
||||
)
|
||||
|
||||
print(
|
||||
f" Complexity {mid_complexity}: Recall@3 = {recall:.3f} ({recall * 100:.1f}%)"
|
||||
)
|
||||
|
||||
if recall >= target_recall:
|
||||
best_complexity = mid_complexity
|
||||
best_recall = recall
|
||||
max_complexity = mid_complexity - 1
|
||||
print(" ✅ Target reached! Searching for lower complexity...")
|
||||
else:
|
||||
min_complexity = mid_complexity + 1
|
||||
print(" ❌ Below target. Searching for higher complexity...")
|
||||
|
||||
if best_complexity is not None:
|
||||
print("\n🎯 Optimal complexity found!")
|
||||
print(f" Complexity: {best_complexity}")
|
||||
print(f" Recall@3: {best_recall:.3f} ({best_recall * 100:.1f}%)")
|
||||
|
||||
# Test a few complexities around the optimal one for verification
|
||||
print("\n🔬 Verification test around optimal complexity:")
|
||||
verification_complexities = [
|
||||
max(1, best_complexity - 2),
|
||||
max(1, best_complexity - 1),
|
||||
best_complexity,
|
||||
best_complexity + 1,
|
||||
best_complexity + 2,
|
||||
]
|
||||
|
||||
for complexity in verification_complexities:
|
||||
if complexity <= 512: # reasonable upper bound
|
||||
recall = binary_search_evaluator.evaluate_recall_at_3(
|
||||
test_captions, complexity, recompute_embeddings=False
|
||||
)
|
||||
status = "✅" if recall >= target_recall else "❌"
|
||||
print(f" {status} Complexity {complexity:3d}: {recall * 100:5.1f}%")
|
||||
|
||||
# Now test the optimal complexity with compact index and recompute for comparison
|
||||
print(
|
||||
f"\n🔄 Testing optimal complexity {best_complexity} on compact index WITH recompute..."
|
||||
)
|
||||
compact_evaluator = RecallEvaluator(args.index, args.baseline_dir)
|
||||
recall_with_recompute = compact_evaluator.evaluate_recall_at_3(
|
||||
test_captions[:10], best_complexity, recompute_embeddings=True
|
||||
)
|
||||
print(
|
||||
f" ✅ Complexity {best_complexity} (compact index with recompute): {recall_with_recompute * 100:.1f}%"
|
||||
)
|
||||
complexity = best_complexity
|
||||
print(
|
||||
f" 📊 Recall difference: {abs(best_recall - recall_with_recompute) * 100:.2f}%"
|
||||
)
|
||||
compact_evaluator.cleanup()
|
||||
else:
|
||||
print(f"\n❌ Could not find complexity achieving {target_recall * 100}% recall")
|
||||
print("All tested complexities were below target.")
|
||||
|
||||
# Cleanup evaluators (keep non-compact index for Stage 4)
|
||||
binary_search_evaluator.cleanup()
|
||||
evaluator.cleanup()
|
||||
|
||||
print("✅ Stage 3 completed! Non-compact index saved for Stage 4.\n")
|
||||
|
||||
if args.stage == "4" or args.stage == "all":
|
||||
# Stage 4: Index comparison (without LLM generation)
|
||||
print("🚀 Starting Stage 4: Index comparison analysis")
|
||||
|
||||
# Use LAION evaluator for index comparison
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
|
||||
# Load caption queries
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
|
||||
# Step 1: Analyze current (compact) index
|
||||
print("\n📏 Analyzing current index (compact, pruned)...")
|
||||
compact_size_metrics = evaluator.analyze_index_sizes()
|
||||
compact_size_metrics["index_type"] = "compact"
|
||||
|
||||
# Step 2: Use existing non-compact index or create if needed
|
||||
if Path(non_compact_index_path).exists():
|
||||
print(
|
||||
f"\n📁 Using existing non-compact index from Stage 3: {non_compact_index_path}"
|
||||
)
|
||||
temp_evaluator = LAIONEvaluator(non_compact_index_path)
|
||||
non_compact_size_metrics = temp_evaluator.analyze_index_sizes()
|
||||
non_compact_size_metrics["index_type"] = "non_compact"
|
||||
else:
|
||||
print("\n🏗️ Creating non-compact index (with embeddings) for comparison...")
|
||||
non_compact_size_metrics = evaluator.create_non_compact_index_for_comparison(
|
||||
non_compact_index_path
|
||||
)
|
||||
|
||||
# Step 3: Compare index sizes (.index only)
|
||||
print("\n📊 Index size comparison (.index only):")
|
||||
print(
|
||||
f" Compact index (current): {compact_size_metrics.get('index_only_mb', 0):.1f} MB"
|
||||
)
|
||||
print(f" Non-compact index: {non_compact_size_metrics.get('index_only_mb', 0):.1f} MB")
|
||||
|
||||
storage_saving = 0.0
|
||||
if non_compact_size_metrics.get("index_only_mb", 0) > 0:
|
||||
storage_saving = (
|
||||
(
|
||||
non_compact_size_metrics.get("index_only_mb", 0)
|
||||
- compact_size_metrics.get("index_only_mb", 0)
|
||||
)
|
||||
/ non_compact_size_metrics.get("index_only_mb", 1)
|
||||
* 100
|
||||
)
|
||||
print(f" Storage saving by compact: {storage_saving:.1f}%")
|
||||
|
||||
# Step 4: Performance comparison between the two indexes
|
||||
if complexity is None:
|
||||
raise ValueError("Complexity is required for index comparison")
|
||||
|
||||
print("\n⚡ Performance comparison between indexes...")
|
||||
performance_metrics = evaluator.compare_index_performance(
|
||||
non_compact_index_path, args.index, captions[:10], complexity=complexity
|
||||
)
|
||||
|
||||
# Combine all metrics
|
||||
combined_metrics = {
|
||||
"current_index": compact_size_metrics,
|
||||
"non_compact_index": non_compact_size_metrics,
|
||||
"performance_comparison": performance_metrics,
|
||||
"storage_saving_percent": storage_saving,
|
||||
}
|
||||
|
||||
# Print comprehensive results
|
||||
evaluator._print_results(combined_metrics)
|
||||
|
||||
# Save results if requested
|
||||
if args.output:
|
||||
print(f"\n💾 Saving results to {args.output}...")
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(combined_metrics, f, indent=2, default=str)
|
||||
print(f"✅ Results saved to {args.output}")
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 4 completed!\n")
|
||||
|
||||
if args.stage in ("5", "all"):
|
||||
print("🚀 Starting Stage 5: Multimodal generation with Qwen2.5-VL")
|
||||
evaluator = LAIONEvaluator(args.index)
|
||||
captions = evaluator.load_queries(args.queries)
|
||||
test_captions = captions[: min(20, len(captions))] # Use subset for generation
|
||||
|
||||
print(f"🧪 Testing multimodal generation with {len(test_captions)} queries")
|
||||
|
||||
# Load Qwen2.5-VL model
|
||||
try:
|
||||
print("Loading Qwen2.5-VL model...")
|
||||
processor, model = load_qwen_vl_model(args.model_name)
|
||||
|
||||
# Run multimodal generation evaluation
|
||||
complexity = args.complexity or 64
|
||||
gen_results = evaluate_multimodal_rag(
|
||||
evaluator.searcher,
|
||||
test_captions,
|
||||
processor=processor,
|
||||
model=model,
|
||||
complexity=complexity,
|
||||
)
|
||||
|
||||
print("\n📊 Multimodal Generation Results:")
|
||||
print(f" Total Queries: {len(test_captions)}")
|
||||
print(f" Avg Search Time: {gen_results['avg_search_time']:.3f}s")
|
||||
print(f" Avg Generation Time: {gen_results['avg_generation_time']:.3f}s")
|
||||
total_time = gen_results["avg_search_time"] + gen_results["avg_generation_time"]
|
||||
search_pct = (gen_results["avg_search_time"] / total_time) * 100
|
||||
gen_pct = (gen_results["avg_generation_time"] / total_time) * 100
|
||||
print(f" Time Distribution: Search {search_pct:.1f}%, Generation {gen_pct:.1f}%")
|
||||
print(" LLM Backend: HuggingFace transformers")
|
||||
print(f" Model: {args.model_name}")
|
||||
|
||||
# Show sample results
|
||||
print("\n📝 Sample Multimodal Generations:")
|
||||
for i, response in enumerate(gen_results["results"][:3]):
|
||||
# Handle both string and dict formats for captions
|
||||
if isinstance(test_captions[i], dict):
|
||||
caption_text = test_captions[i].get("query", str(test_captions[i]))
|
||||
else:
|
||||
caption_text = str(test_captions[i])
|
||||
print(f" Query {i + 1}: {caption_text[:60]}...")
|
||||
print(f" Response {i + 1}: {response[:100]}...")
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Multimodal generation evaluation failed: {e}")
|
||||
print("💡 Make sure transformers and Qwen2.5-VL are installed")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
evaluator.cleanup()
|
||||
print("✅ Stage 5 completed!\n")
|
||||
|
||||
if args.stage == "all":
|
||||
print("🎉 All evaluation stages completed successfully!")
|
||||
print("\n📋 Summary:")
|
||||
print(" Stage 2: ✅ Multimodal Recall@3 evaluation completed")
|
||||
print(" Stage 3: ✅ Optimal complexity found")
|
||||
print(" Stage 4: ✅ Index comparison analysis completed")
|
||||
print(" Stage 5: ✅ Multimodal generation evaluation completed")
|
||||
print("\n🔧 Recommended next steps:")
|
||||
print(" - Use optimal complexity for best speed/accuracy balance")
|
||||
print(" - Review index comparison for storage vs performance tradeoffs")
|
||||
|
||||
# Clean up non-compact index after all stages complete
|
||||
print("\n🧹 Cleaning up temporary non-compact index...")
|
||||
if Path(non_compact_index_path).exists():
|
||||
temp_index_dir = Path(non_compact_index_path).parent
|
||||
temp_index_name = Path(non_compact_index_path).name
|
||||
for temp_file in temp_index_dir.glob(f"{temp_index_name}*"):
|
||||
temp_file.unlink()
|
||||
print(f"✅ Cleaned up {non_compact_index_path}")
|
||||
else:
|
||||
print("📝 No temporary index to clean up")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Evaluation interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Stage {args.stage} failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
576
benchmarks/laion/setup_laion.py
Normal file
576
benchmarks/laion/setup_laion.py
Normal file
@@ -0,0 +1,576 @@
|
||||
"""
|
||||
LAION Multimodal Benchmark Setup Script
|
||||
Downloads LAION subset and builds LEANN index with sentence embeddings
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from leann import LeannBuilder
|
||||
from PIL import Image
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class LAIONSetup:
|
||||
def __init__(self, data_dir: str = "data"):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.images_dir = self.data_dir / "laion_images"
|
||||
self.metadata_file = self.data_dir / "laion_metadata.jsonl"
|
||||
|
||||
# Create directories
|
||||
self.data_dir.mkdir(exist_ok=True)
|
||||
self.images_dir.mkdir(exist_ok=True)
|
||||
|
||||
async def download_single_image(self, session, sample_data, semaphore, progress_bar):
|
||||
"""Download a single image asynchronously"""
|
||||
async with semaphore: # Limit concurrent downloads
|
||||
try:
|
||||
image_url = sample_data["url"]
|
||||
image_path = sample_data["image_path"]
|
||||
|
||||
# Skip if already exists
|
||||
if os.path.exists(image_path):
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
|
||||
async with session.get(image_url, timeout=10) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
|
||||
# Verify it's a valid image
|
||||
try:
|
||||
img = Image.open(io.BytesIO(content))
|
||||
img = img.convert("RGB")
|
||||
img.save(image_path, "JPEG")
|
||||
progress_bar.update(1)
|
||||
return sample_data
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None # Skip invalid images
|
||||
else:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
progress_bar.update(1)
|
||||
return None
|
||||
|
||||
def download_laion_subset(self, num_samples: int = 1000):
|
||||
"""Download LAION subset from HuggingFace datasets with async parallel downloading"""
|
||||
print(f"📥 Downloading LAION subset ({num_samples} samples)...")
|
||||
|
||||
# Load LAION-400M subset from HuggingFace
|
||||
print("🤗 Loading from HuggingFace datasets...")
|
||||
dataset = load_dataset("laion/laion400m", split="train", streaming=True)
|
||||
|
||||
# Collect sample metadata first (fast)
|
||||
print("📋 Collecting sample metadata...")
|
||||
candidates = []
|
||||
for sample in dataset:
|
||||
if len(candidates) >= num_samples * 3: # Get 3x more candidates in case some fail
|
||||
break
|
||||
|
||||
image_url = sample.get("url", "")
|
||||
caption = sample.get("caption", "")
|
||||
|
||||
if not image_url or not caption:
|
||||
continue
|
||||
|
||||
image_filename = f"laion_{len(candidates):06d}.jpg"
|
||||
image_path = self.images_dir / image_filename
|
||||
|
||||
candidate = {
|
||||
"id": f"laion_{len(candidates):06d}",
|
||||
"url": image_url,
|
||||
"caption": caption,
|
||||
"image_path": str(image_path),
|
||||
"width": sample.get("original_width", 512),
|
||||
"height": sample.get("original_height", 512),
|
||||
"similarity": sample.get("similarity", 0.0),
|
||||
}
|
||||
candidates.append(candidate)
|
||||
|
||||
print(
|
||||
f"📊 Collected {len(candidates)} candidates, downloading {num_samples} in parallel..."
|
||||
)
|
||||
|
||||
# Download images in parallel
|
||||
async def download_batch():
|
||||
semaphore = asyncio.Semaphore(20) # Limit to 20 concurrent downloads
|
||||
connector = aiohttp.TCPConnector(limit=100, limit_per_host=20)
|
||||
timeout = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
progress_bar = tqdm(total=len(candidates[: num_samples * 2]), desc="Downloading images")
|
||||
|
||||
async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
|
||||
tasks = []
|
||||
for candidate in candidates[: num_samples * 2]: # Try 2x more than needed
|
||||
task = self.download_single_image(session, candidate, semaphore, progress_bar)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all downloads
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
progress_bar.close()
|
||||
|
||||
# Filter successful downloads
|
||||
successful = [r for r in results if r is not None and not isinstance(r, Exception)]
|
||||
return successful[:num_samples]
|
||||
|
||||
# Run async download
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
samples = loop.run_until_complete(download_batch())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Save metadata
|
||||
with open(self.metadata_file, "w", encoding="utf-8") as f:
|
||||
for sample in samples:
|
||||
f.write(json.dumps(sample) + "\n")
|
||||
|
||||
print(f"✅ Downloaded {len(samples)} real LAION samples with async parallel downloading")
|
||||
return samples
|
||||
|
||||
def generate_clip_image_embeddings(self, samples: list[dict]):
|
||||
"""Generate CLIP image embeddings for downloaded images"""
|
||||
print("🔍 Generating CLIP image embeddings...")
|
||||
|
||||
# Load sentence-transformers CLIP (ViT-L/14, 768-dim) for image embeddings
|
||||
# This single model can encode both images and text.
|
||||
model = SentenceTransformer("clip-ViT-L-14")
|
||||
|
||||
embeddings = []
|
||||
valid_samples = []
|
||||
|
||||
for sample in tqdm(samples, desc="Processing images"):
|
||||
try:
|
||||
# Load image
|
||||
image_path = sample["image_path"]
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
|
||||
# Encode image to 768-dim embedding via sentence-transformers (normalized)
|
||||
vec = model.encode(
|
||||
[image],
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
batch_size=1,
|
||||
show_progress_bar=False,
|
||||
)[0]
|
||||
embeddings.append(vec.astype(np.float32))
|
||||
valid_samples.append(sample)
|
||||
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Failed to process {sample['id']}: {e}")
|
||||
# Skip invalid images
|
||||
|
||||
embeddings = np.array(embeddings, dtype=np.float32)
|
||||
|
||||
# Save embeddings
|
||||
embeddings_file = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(embeddings_file, embeddings)
|
||||
print(f"✅ Generated {len(embeddings)} image embeddings, shape: {embeddings.shape}")
|
||||
|
||||
return embeddings, valid_samples
|
||||
|
||||
def build_faiss_baseline(
|
||||
self, embeddings: np.ndarray, samples: list[dict], output_dir: str = "baseline"
|
||||
):
|
||||
"""Build FAISS flat baseline using CLIP image embeddings"""
|
||||
print("🔨 Building FAISS Flat baseline...")
|
||||
|
||||
from leann_backend_hnsw import faiss
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
baseline_path = os.path.join(output_dir, "faiss_flat.index")
|
||||
metadata_path = os.path.join(output_dir, "metadata.pkl")
|
||||
|
||||
if os.path.exists(baseline_path) and os.path.exists(metadata_path):
|
||||
print(f"✅ Baseline already exists at {baseline_path}")
|
||||
return baseline_path
|
||||
|
||||
# Extract image IDs (must be present)
|
||||
if not samples or "id" not in samples[0]:
|
||||
raise KeyError("samples missing 'id' field for FAISS baseline")
|
||||
image_ids: list[str] = [str(sample["id"]) for sample in samples]
|
||||
|
||||
print(f"📐 Embedding shape: {embeddings.shape}")
|
||||
print(f"📄 Processing {len(image_ids)} images")
|
||||
|
||||
# Build FAISS flat index
|
||||
print("🏗️ Building FAISS IndexFlatIP...")
|
||||
dimension = embeddings.shape[1]
|
||||
index = faiss.IndexFlatIP(dimension)
|
||||
|
||||
# Add embeddings to flat index
|
||||
embeddings_f32 = embeddings.astype(np.float32)
|
||||
index.add(embeddings_f32.shape[0], faiss.swig_ptr(embeddings_f32))
|
||||
|
||||
# Save index and metadata
|
||||
faiss.write_index(index, baseline_path)
|
||||
with open(metadata_path, "wb") as f:
|
||||
pickle.dump(image_ids, f)
|
||||
|
||||
print(f"✅ FAISS baseline saved to {baseline_path}")
|
||||
print(f"✅ Metadata saved to {metadata_path}")
|
||||
print(f"📊 Total vectors: {index.ntotal}")
|
||||
|
||||
return baseline_path
|
||||
|
||||
def create_leann_passages(self, samples: list[dict]):
|
||||
"""Create LEANN-compatible passages from LAION data"""
|
||||
print("📝 Creating LEANN passages...")
|
||||
|
||||
passages_file = self.data_dir / "laion_passages.jsonl"
|
||||
|
||||
with open(passages_file, "w", encoding="utf-8") as f:
|
||||
for i, sample in enumerate(samples):
|
||||
passage = {
|
||||
"id": sample["id"],
|
||||
"text": sample["caption"], # Use caption as searchable text
|
||||
"metadata": {
|
||||
"image_url": sample["url"],
|
||||
"image_path": sample.get("image_path", ""),
|
||||
"width": sample["width"],
|
||||
"height": sample["height"],
|
||||
"similarity": sample["similarity"],
|
||||
"image_index": i, # Index for embedding lookup
|
||||
},
|
||||
}
|
||||
f.write(json.dumps(passage) + "\n")
|
||||
|
||||
print(f"✅ Created {len(samples)} passages")
|
||||
return passages_file
|
||||
|
||||
def build_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build compact LEANN index with CLIP embeddings (recompute=True, compact=True)"""
|
||||
print(f"🏗️ Building compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Save CLIP embeddings (npy) and also a pickle with (ids, embeddings)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
|
||||
# Prepare ids in the same order as passages_file (matches embeddings order)
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - compact with recompute
|
||||
# Note: For multimodal case, we need to handle embeddings differently
|
||||
# Let's try using sentence-transformers mode but with custom embeddings
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
# Use CLIP text encoder (ViT-L/14) to match image space (768-dim)
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
# HNSW params (or forwarded to chosen backend)
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
# Compact/pruned with recompute at query time
|
||||
is_recompute=True,
|
||||
is_compact=True,
|
||||
distance_metric="cosine", # CLIP uses normalized vectors; cosine is appropriate
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages (text + metadata)
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def build_non_compact_index(
|
||||
self, passages_file: Path, embeddings: np.ndarray, index_path: str, backend: str = "hnsw"
|
||||
):
|
||||
"""Build non-compact LEANN index with CLIP embeddings (recompute=False, compact=False)"""
|
||||
print(f"🏗️ Building non-compact LEANN index with {backend} backend...")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Ensure embeddings are saved (npy + pickle)
|
||||
npy_path = self.data_dir / "clip_image_embeddings.npy"
|
||||
if not npy_path.exists():
|
||||
np.save(npy_path, embeddings)
|
||||
print(f"💾 Saved CLIP embeddings to {npy_path}")
|
||||
# Prepare ids in same order as passages_file
|
||||
ids: list[str] = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
rec = json.loads(line)
|
||||
ids.append(str(rec["id"]))
|
||||
if len(ids) != embeddings.shape[0]:
|
||||
raise ValueError(
|
||||
f"IDs count ({len(ids)}) does not match embeddings ({embeddings.shape[0]})."
|
||||
)
|
||||
pkl_path = self.data_dir / "clip_image_embeddings.pkl"
|
||||
if not pkl_path.exists():
|
||||
with open(pkl_path, "wb") as pf:
|
||||
pickle.dump((ids, embeddings.astype(np.float32)), pf)
|
||||
print(f"💾 Saved (ids, embeddings) pickle to {pkl_path}")
|
||||
|
||||
# Initialize builder - non-compact without recompute
|
||||
builder = LeannBuilder(
|
||||
backend_name=backend,
|
||||
embedding_model="clip-ViT-L-14",
|
||||
embedding_mode="sentence-transformers",
|
||||
graph_degree=32,
|
||||
complexity=64,
|
||||
is_recompute=False, # Store embeddings (no recompute needed)
|
||||
is_compact=False, # Store full index (not pruned)
|
||||
distance_metric="cosine",
|
||||
num_threads=4,
|
||||
)
|
||||
|
||||
# Add passages - embeddings will be loaded from file
|
||||
print("📚 Adding passages...")
|
||||
self._add_passages_with_embeddings(builder, passages_file, embeddings)
|
||||
|
||||
print(f"🔨 Building non-compact index at {index_path} from precomputed embeddings...")
|
||||
builder.build_index_from_embeddings(index_path, str(pkl_path))
|
||||
|
||||
build_time = time.time() - start_time
|
||||
print(f"✅ Non-compact index built in {build_time:.2f}s")
|
||||
|
||||
# Analyze index size
|
||||
self._analyze_index_size(index_path)
|
||||
|
||||
return index_path
|
||||
|
||||
def _add_passages_with_embeddings(self, builder, passages_file: Path, embeddings: np.ndarray):
|
||||
"""Helper to add passages with pre-computed CLIP embeddings"""
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in tqdm(f, desc="Adding passages"):
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
|
||||
# Add image metadata - LEANN will handle embeddings separately
|
||||
# Note: We store image metadata and caption text for searchability
|
||||
# Important: ensure passage ID in metadata matches vector ID
|
||||
builder.add_text(
|
||||
text=passage["text"], # Image caption for searchability
|
||||
metadata={**passage["metadata"], "id": passage["id"]},
|
||||
)
|
||||
|
||||
def _analyze_index_size(self, index_path: str):
|
||||
"""Analyze index file sizes"""
|
||||
print("📏 Analyzing index sizes...")
|
||||
|
||||
index_path = Path(index_path)
|
||||
index_dir = index_path.parent
|
||||
index_name = index_path.name # e.g., laion_index.leann
|
||||
index_prefix = index_path.stem # e.g., laion_index
|
||||
|
||||
files = [
|
||||
(f"{index_prefix}.index", ".index", "core"),
|
||||
(f"{index_name}.meta.json", ".meta.json", "core"),
|
||||
(f"{index_name}.ids.txt", ".ids.txt", "core"),
|
||||
(f"{index_name}.passages.jsonl", ".passages.jsonl", "passages"),
|
||||
(f"{index_name}.passages.idx", ".passages.idx", "passages"),
|
||||
]
|
||||
|
||||
def _fmt_size(bytes_val: int) -> str:
|
||||
if bytes_val < 1024:
|
||||
return f"{bytes_val} B"
|
||||
kb = bytes_val / 1024
|
||||
if kb < 1024:
|
||||
return f"{kb:.1f} KB"
|
||||
mb = kb / 1024
|
||||
if mb < 1024:
|
||||
return f"{mb:.2f} MB"
|
||||
gb = mb / 1024
|
||||
return f"{gb:.2f} GB"
|
||||
|
||||
total_index_only_mb = 0.0
|
||||
total_all_mb = 0.0
|
||||
for filename, label, group in files:
|
||||
file_path = index_dir / filename
|
||||
if file_path.exists():
|
||||
size_bytes = file_path.stat().st_size
|
||||
print(f" {label}: {_fmt_size(size_bytes)}")
|
||||
size_mb = size_bytes / (1024 * 1024)
|
||||
total_all_mb += size_mb
|
||||
if group == "core":
|
||||
total_index_only_mb += size_mb
|
||||
else:
|
||||
print(f" {label}: (missing)")
|
||||
print(f" Total (index only, exclude passages): {total_index_only_mb:.2f} MB")
|
||||
print(f" Total (including passages): {total_all_mb:.2f} MB")
|
||||
|
||||
def create_evaluation_queries(self, samples: list[dict], num_queries: int = 200):
|
||||
"""Create evaluation queries from captions"""
|
||||
print(f"📝 Creating {num_queries} evaluation queries...")
|
||||
|
||||
# Sample random captions as queries
|
||||
import random
|
||||
|
||||
random.seed(42) # For reproducibility
|
||||
|
||||
query_samples = random.sample(samples, min(num_queries, len(samples)))
|
||||
|
||||
queries_file = self.data_dir / "evaluation_queries.jsonl"
|
||||
with open(queries_file, "w", encoding="utf-8") as f:
|
||||
for sample in query_samples:
|
||||
query = {
|
||||
"id": sample["id"],
|
||||
"query": sample["caption"],
|
||||
"ground_truth_id": sample["id"], # For potential recall evaluation
|
||||
}
|
||||
f.write(json.dumps(query) + "\n")
|
||||
|
||||
print(f"✅ Created {len(query_samples)} evaluation queries")
|
||||
return queries_file
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Setup LAION Multimodal Benchmark")
|
||||
parser.add_argument("--data-dir", default="data", help="Data directory")
|
||||
parser.add_argument("--num-samples", type=int, default=1000, help="Number of LAION samples")
|
||||
parser.add_argument("--num-queries", type=int, default=50, help="Number of evaluation queries")
|
||||
parser.add_argument("--index-path", default="data/laion_index.leann", help="Output index path")
|
||||
parser.add_argument(
|
||||
"--backend", default="hnsw", choices=["hnsw", "diskann"], help="LEANN backend"
|
||||
)
|
||||
parser.add_argument("--skip-download", action="store_true", help="Skip LAION dataset download")
|
||||
parser.add_argument("--skip-build", action="store_true", help="Skip index building")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🚀 Setting up LAION Multimodal Benchmark")
|
||||
print("=" * 50)
|
||||
|
||||
try:
|
||||
# Initialize setup
|
||||
setup = LAIONSetup(args.data_dir)
|
||||
|
||||
# Step 1: Download LAION subset
|
||||
if not args.skip_download:
|
||||
print("\n📦 Step 1: Download LAION subset")
|
||||
samples = setup.download_laion_subset(args.num_samples)
|
||||
|
||||
# Step 2: Generate CLIP image embeddings
|
||||
print("\n🔍 Step 2: Generate CLIP image embeddings")
|
||||
embeddings, valid_samples = setup.generate_clip_image_embeddings(samples)
|
||||
|
||||
# Step 3: Create LEANN passages (image metadata with embeddings)
|
||||
print("\n📝 Step 3: Create LEANN passages")
|
||||
passages_file = setup.create_leann_passages(valid_samples)
|
||||
else:
|
||||
print("⏭️ Skipping LAION dataset download")
|
||||
# Load existing data
|
||||
passages_file = setup.data_dir / "laion_passages.jsonl"
|
||||
embeddings_file = setup.data_dir / "clip_image_embeddings.npy"
|
||||
|
||||
if not passages_file.exists() or not embeddings_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passages or embeddings file not found. Run without --skip-download first."
|
||||
)
|
||||
|
||||
embeddings = np.load(embeddings_file)
|
||||
print(f"📊 Loaded {len(embeddings)} embeddings from {embeddings_file}")
|
||||
|
||||
# Step 4: Build LEANN indexes (both compact and non-compact)
|
||||
if not args.skip_build:
|
||||
print("\n🏗️ Step 4: Build LEANN indexes with CLIP image embeddings")
|
||||
|
||||
# Build compact index (production mode - small, recompute required)
|
||||
compact_index_path = args.index_path
|
||||
print(f"Building compact index: {compact_index_path}")
|
||||
setup.build_compact_index(passages_file, embeddings, compact_index_path, args.backend)
|
||||
|
||||
# Build non-compact index (comparison mode - large, fast search)
|
||||
non_compact_index_path = args.index_path.replace(".leann", "_noncompact.leann")
|
||||
print(f"Building non-compact index: {non_compact_index_path}")
|
||||
setup.build_non_compact_index(
|
||||
passages_file, embeddings, non_compact_index_path, args.backend
|
||||
)
|
||||
|
||||
# Step 5: Build FAISS flat baseline
|
||||
print("\n🔨 Step 5: Build FAISS flat baseline")
|
||||
if not args.skip_download:
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
else:
|
||||
# Load valid_samples from passages file for FAISS baseline
|
||||
valid_samples = []
|
||||
with open(passages_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
passage = json.loads(line)
|
||||
valid_samples.append({"id": passage["id"], "caption": passage["text"]})
|
||||
baseline_path = setup.build_faiss_baseline(embeddings, valid_samples)
|
||||
|
||||
# Step 6: Create evaluation queries
|
||||
print("\n📝 Step 6: Create evaluation queries")
|
||||
queries_file = setup.create_evaluation_queries(valid_samples, args.num_queries)
|
||||
else:
|
||||
print("⏭️ Skipping index building")
|
||||
baseline_path = "data/baseline/faiss_index.bin"
|
||||
queries_file = setup.data_dir / "evaluation_queries.jsonl"
|
||||
|
||||
print("\n🎉 Setup completed successfully!")
|
||||
print("📊 Summary:")
|
||||
if not args.skip_download:
|
||||
print(f" Downloaded samples: {len(samples)}")
|
||||
print(f" Valid samples with embeddings: {len(valid_samples)}")
|
||||
else:
|
||||
print(f" Loaded {len(embeddings)} embeddings")
|
||||
|
||||
if not args.skip_build:
|
||||
print(f" Compact index: {compact_index_path}")
|
||||
print(f" Non-compact index: {non_compact_index_path}")
|
||||
print(f" FAISS baseline: {baseline_path}")
|
||||
print(f" Queries: {queries_file}")
|
||||
|
||||
print("\n🔧 Next steps:")
|
||||
print(f" Run evaluation: python evaluate_laion.py --index {compact_index_path}")
|
||||
print(f" Or compare with: python evaluate_laion.py --index {non_compact_index_path}")
|
||||
else:
|
||||
print(" Skipped building indexes")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Setup interrupted by user")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Setup failed: {e}")
|
||||
exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
342
benchmarks/llm_utils.py
Normal file
342
benchmarks/llm_utils.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
LLM utils for RAG benchmarks with Qwen3-8B and Qwen2.5-VL (multimodal)
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
|
||||
|
||||
|
||||
def is_qwen3_model(model_name):
|
||||
"""Check if model is Qwen3"""
|
||||
return "Qwen3" in model_name or "qwen3" in model_name.lower()
|
||||
|
||||
|
||||
def is_qwen_vl_model(model_name):
|
||||
"""Check if model is Qwen2.5-VL"""
|
||||
return "Qwen2.5-VL" in model_name or "qwen2.5-vl" in model_name.lower()
|
||||
|
||||
|
||||
def apply_qwen3_chat_template(tokenizer, prompt):
|
||||
"""Apply Qwen3 chat template with thinking enabled"""
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
def extract_thinking_answer(response):
|
||||
"""Extract final answer from Qwen3 thinking model response"""
|
||||
if "<think>" in response and "</think>" in response:
|
||||
try:
|
||||
think_end = response.index("</think>") + len("</think>")
|
||||
final_answer = response[think_end:].strip()
|
||||
return final_answer
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return response.strip()
|
||||
|
||||
|
||||
def load_hf_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load HuggingFace model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading HF: {model_name}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
return tokenizer, model
|
||||
|
||||
|
||||
def load_vllm_model(model_name="Qwen/Qwen3-8B", trust_remote_code=False):
|
||||
"""Load vLLM model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not VLLM_AVAILABLE:
|
||||
raise ImportError("vllm not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading vLLM: {model_name}")
|
||||
llm = LLM(model=model_name, trust_remote_code=trust_remote_code)
|
||||
|
||||
# Qwen3 specific config
|
||||
if is_qwen3_model(model_name):
|
||||
stop_tokens = ["<|im_end|>", "<|end_of_text|>"]
|
||||
max_tokens = 2048
|
||||
else:
|
||||
stop_tokens = None
|
||||
max_tokens = 1024
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, max_tokens=max_tokens, stop=stop_tokens)
|
||||
return llm, sampling_params
|
||||
|
||||
|
||||
def generate_hf(tokenizer, model, prompt, max_tokens=None):
|
||||
"""Generate with HF - supports Qwen3 thinking models"""
|
||||
model_name = getattr(model, "name_or_path", "unknown")
|
||||
is_qwen3 = is_qwen3_model(model_name)
|
||||
|
||||
# Apply chat template for Qwen3
|
||||
if is_qwen3:
|
||||
prompt = apply_qwen3_chat_template(tokenizer, prompt)
|
||||
max_tokens = max_tokens or 2048
|
||||
else:
|
||||
max_tokens = max_tokens or 1024
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
response = response[len(prompt) :].strip()
|
||||
|
||||
# Extract final answer for thinking models
|
||||
if is_qwen3:
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def generate_vllm(llm, sampling_params, prompt):
|
||||
"""Generate with vLLM - supports Qwen3 thinking models"""
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
response = outputs[0].outputs[0].text.strip()
|
||||
|
||||
# Extract final answer for Qwen3 thinking models
|
||||
model_name = str(llm.llm_engine.model_config.model)
|
||||
if is_qwen3_model(model_name):
|
||||
return extract_thinking_answer(response)
|
||||
return response
|
||||
|
||||
|
||||
def create_prompt(context, query, domain="default"):
|
||||
"""Create RAG prompt"""
|
||||
if domain == "emails":
|
||||
return f"Email content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "finance":
|
||||
return f"Financial content:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
elif domain == "multimodal":
|
||||
return f"Image context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
else:
|
||||
return f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_rag(searcher, llm_func, queries, domain="default", top_k=3, complexity=64):
|
||||
"""Simple RAG evaluation with timing"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query in enumerate(queries):
|
||||
# Search
|
||||
start = time.time()
|
||||
docs = searcher.search(query, top_k=top_k, complexity=complexity)
|
||||
search_time = time.time() - start
|
||||
|
||||
# Generate
|
||||
context = "\n\n".join([doc.text for doc in docs])
|
||||
prompt = create_prompt(context, query, domain)
|
||||
|
||||
start = time.time()
|
||||
response = llm_func(prompt)
|
||||
gen_time = time.time() - start
|
||||
|
||||
search_times.append(search_time)
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
def load_qwen_vl_model(model_name="Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=False):
|
||||
"""Load Qwen2.5-VL multimodal model
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model to load
|
||||
trust_remote_code (bool): Whether to allow execution of code from the model repository.
|
||||
Defaults to False for security. Only enable for trusted models.
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError("transformers not available")
|
||||
|
||||
if trust_remote_code:
|
||||
print(
|
||||
"⚠️ WARNING: Loading model with trust_remote_code=True. This can execute arbitrary code."
|
||||
)
|
||||
|
||||
print(f"Loading Qwen2.5-VL: {model_name}")
|
||||
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=trust_remote_code)
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load with AutoModelForVision2Seq, trying specific class: {e}")
|
||||
|
||||
# Fallback to specific class
|
||||
try:
|
||||
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
return processor, model
|
||||
|
||||
except Exception as e2:
|
||||
raise ImportError(f"Failed to load Qwen2.5-VL model: {e2}")
|
||||
|
||||
|
||||
def generate_qwen_vl(processor, model, prompt, image_path=None, max_tokens=512):
|
||||
"""Generate with Qwen2.5-VL multimodal model"""
|
||||
from PIL import Image
|
||||
|
||||
# Prepare inputs
|
||||
if image_path:
|
||||
image = Image.open(image_path)
|
||||
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
||||
else:
|
||||
inputs = processor(text=prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
# Generate
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs, max_new_tokens=max_tokens, do_sample=False, temperature=0.1
|
||||
)
|
||||
|
||||
# Decode response
|
||||
generated_ids = generated_ids[:, inputs["input_ids"].shape[1] :]
|
||||
response = processor.decode(generated_ids[0], skip_special_tokens=True)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def create_multimodal_prompt(context, query, image_descriptions, task_type="images"):
|
||||
"""Create prompt for multimodal RAG"""
|
||||
if task_type == "images":
|
||||
return f"""Based on the retrieved images and their descriptions, answer the following question.
|
||||
|
||||
Retrieved Image Descriptions:
|
||||
{context}
|
||||
|
||||
Question: {query}
|
||||
|
||||
Provide a detailed answer based on the visual content described above."""
|
||||
|
||||
return f"Context: {context}\nQuestion: {query}\nAnswer:"
|
||||
|
||||
|
||||
def evaluate_multimodal_rag(searcher, queries, processor=None, model=None, complexity=64):
|
||||
"""Evaluate multimodal RAG with Qwen2.5-VL"""
|
||||
search_times = []
|
||||
gen_times = []
|
||||
results = []
|
||||
|
||||
for i, query_item in enumerate(queries):
|
||||
# Handle both string and dict formats for queries
|
||||
if isinstance(query_item, dict):
|
||||
query = query_item.get("query", "")
|
||||
image_path = query_item.get("image_path") # Optional reference image
|
||||
else:
|
||||
query = str(query_item)
|
||||
image_path = None
|
||||
|
||||
# Search
|
||||
start_time = time.time()
|
||||
search_results = searcher.search(query, top_k=3, complexity=complexity)
|
||||
search_time = time.time() - start_time
|
||||
search_times.append(search_time)
|
||||
|
||||
# Prepare context from search results
|
||||
context_parts = []
|
||||
for result in search_results:
|
||||
context_parts.append(f"- {result.text}")
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
# Generate with multimodal model
|
||||
start_time = time.time()
|
||||
if processor and model:
|
||||
prompt = create_multimodal_prompt(context, query, context_parts)
|
||||
response = generate_qwen_vl(processor, model, prompt, image_path)
|
||||
else:
|
||||
response = f"Context: {context}"
|
||||
gen_time = time.time() - start_time
|
||||
|
||||
gen_times.append(gen_time)
|
||||
results.append(response)
|
||||
|
||||
if i < 3:
|
||||
print(f"Q{i + 1}: Search={search_time:.3f}s, Gen={gen_time:.3f}s")
|
||||
|
||||
return {
|
||||
"avg_search_time": sum(search_times) / len(search_times),
|
||||
"avg_generation_time": sum(gen_times) / len(gen_times),
|
||||
"results": results,
|
||||
}
|
||||
@@ -2,20 +2,20 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str
|
||||
batch_sizes: List[int]
|
||||
batch_sizes: list[int]
|
||||
seq_length: int
|
||||
num_runs: int
|
||||
use_fp16: bool = True
|
||||
@@ -28,47 +28,45 @@ class BenchmarkConfig:
|
||||
|
||||
class GraphContainer:
|
||||
"""Container for managing graphs for different batch sizes (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, seq_length: int):
|
||||
self.model = model
|
||||
self.seq_length = seq_length
|
||||
self.graphs: Dict[int, 'GraphWrapper'] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> 'GraphWrapper':
|
||||
self.graphs: dict[int, GraphWrapper] = {}
|
||||
|
||||
def get_or_create(self, batch_size: int) -> "GraphWrapper":
|
||||
if batch_size not in self.graphs:
|
||||
self.graphs[batch_size] = GraphWrapper(
|
||||
self.model, batch_size, self.seq_length
|
||||
)
|
||||
self.graphs[batch_size] = GraphWrapper(self.model, batch_size, self.seq_length)
|
||||
return self.graphs[batch_size]
|
||||
|
||||
|
||||
class GraphWrapper:
|
||||
"""Wrapper for graph capture and replay (CUDA graphs on NVIDIA, regular on others)."""
|
||||
|
||||
|
||||
def __init__(self, model: nn.Module, batch_size: int, seq_length: int):
|
||||
self.model = model
|
||||
self.device = self._get_device()
|
||||
self.static_input = self._create_random_batch(batch_size, seq_length)
|
||||
self.static_attention_mask = torch.ones_like(self.static_input)
|
||||
|
||||
|
||||
# Warm up
|
||||
self._warmup()
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, 'CUDAGraph'):
|
||||
if torch.cuda.is_available() and hasattr(torch.cuda, "CUDAGraph"):
|
||||
# Capture graph
|
||||
self.graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.graph):
|
||||
self.static_output = self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
self.use_cuda_graph = True
|
||||
else:
|
||||
# For MPS or CPU, just store the model
|
||||
self.use_cuda_graph = False
|
||||
self.static_output = None
|
||||
|
||||
|
||||
def _get_device(self) -> str:
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
@@ -76,22 +74,20 @@ class GraphWrapper:
|
||||
return "mps"
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int, seq_length: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000, (batch_size, seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
0, 1000, (batch_size, seq_length), device=self.device, dtype=torch.long
|
||||
)
|
||||
|
||||
|
||||
def _warmup(self, num_warmup: int = 3):
|
||||
with torch.no_grad():
|
||||
for _ in range(num_warmup):
|
||||
self.model(
|
||||
input_ids=self.static_input,
|
||||
attention_mask=self.static_attention_mask
|
||||
attention_mask=self.static_attention_mask,
|
||||
)
|
||||
|
||||
|
||||
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_cuda_graph:
|
||||
self.static_input.copy_(input_ids)
|
||||
@@ -105,14 +101,14 @@ class GraphWrapper:
|
||||
|
||||
class ModelOptimizer:
|
||||
"""Applies various optimizations to the model."""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def optimize(model: nn.Module, config: BenchmarkConfig) -> nn.Module:
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Cannot optimize None model")
|
||||
|
||||
|
||||
# Move to GPU
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
@@ -124,53 +120,59 @@ class ModelOptimizer:
|
||||
model = model.cpu()
|
||||
device = "cpu"
|
||||
print(f"- Model moved to {device}")
|
||||
|
||||
|
||||
# FP16
|
||||
if config.use_fp16 and not config.use_int4:
|
||||
model = model.half()
|
||||
# use torch compile
|
||||
model = torch.compile(model)
|
||||
print("- Using FP16 precision")
|
||||
|
||||
|
||||
# Check if using SDPA (only on CUDA)
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Flash Attention (only on CUDA)
|
||||
if config.use_flash_attention and torch.cuda.is_available():
|
||||
try:
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attention import FlashAttention # noqa: F401
|
||||
|
||||
print("- Flash Attention 2 available")
|
||||
if hasattr(model.config, "attention_mode"):
|
||||
model.config.attention_mode = "flash_attention_2"
|
||||
print(" - Enabled Flash Attention 2 mode")
|
||||
except ImportError:
|
||||
print("- Flash Attention not available")
|
||||
|
||||
|
||||
# Memory efficient attention (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
from xformers.ops import memory_efficient_attention # noqa: F401
|
||||
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class Timer:
|
||||
"""Handles accurate GPU timing using GPU events or CPU timing."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
if torch.cuda.is_available():
|
||||
self.start_event = torch.cuda.Event(enable_timing=True)
|
||||
@@ -182,7 +184,7 @@ class Timer:
|
||||
else:
|
||||
# CPU timing
|
||||
self.use_gpu_timing = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timing(self):
|
||||
if self.use_gpu_timing:
|
||||
@@ -195,7 +197,7 @@ class Timer:
|
||||
start_time = time.time()
|
||||
yield
|
||||
self.cpu_elapsed = time.time() - start_time
|
||||
|
||||
|
||||
def elapsed_time(self) -> float:
|
||||
if self.use_gpu_timing:
|
||||
return self.start_event.elapsed_time(self.end_event) / 1000 # ms to seconds
|
||||
@@ -205,14 +207,14 @@ class Timer:
|
||||
|
||||
class Benchmark:
|
||||
"""Main benchmark runner."""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
try:
|
||||
self.model = self._load_model()
|
||||
if self.model is None:
|
||||
raise ValueError("Model initialization failed - model is None")
|
||||
|
||||
|
||||
# Only use CUDA graphs on NVIDIA GPUs
|
||||
if config.use_cuda_graphs and torch.cuda.is_available():
|
||||
self.graphs = GraphContainer(self.model, config.seq_length)
|
||||
@@ -220,25 +222,27 @@ class Benchmark:
|
||||
self.graphs = None
|
||||
self.timer = Timer()
|
||||
except Exception as e:
|
||||
print(f"ERROR in benchmark initialization: {str(e)}")
|
||||
print(f"ERROR in benchmark initialization: {e!s}")
|
||||
raise
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
try:
|
||||
# Int4 quantization using HuggingFace integration
|
||||
if self.config.use_int4:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
print(f"- bitsandbytes version: {bnb.__version__}")
|
||||
|
||||
# 检查是否使用自定义的8bit量化
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
# Check if using custom 8bit quantization
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Using custom Linear8bitLt replacement for all linear layers")
|
||||
|
||||
# 加载原始模型(不使用量化配置)
|
||||
|
||||
# Load original model (without quantization config)
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
# set default to half
|
||||
torch.set_default_dtype(torch.float16)
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
@@ -246,112 +250,121 @@ class Benchmark:
|
||||
self.config.model_path,
|
||||
torch_dtype=compute_dtype,
|
||||
)
|
||||
|
||||
# 定义替换函数
|
||||
|
||||
# Define replacement function
|
||||
def replace_linear_with_linear8bitlt(model):
|
||||
"""递归地将模型中的所有nn.Linear层替换为Linear8bitLt"""
|
||||
"""Recursively replace all nn.Linear layers with Linear8bitLt"""
|
||||
for name, module in list(model.named_children()):
|
||||
if isinstance(module, nn.Linear):
|
||||
# 获取原始线性层的参数
|
||||
# Get original linear layer parameters
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
|
||||
# 创建8bit线性层
|
||||
|
||||
# Create 8bit linear layer
|
||||
# print size
|
||||
print(f"in_features: {in_features}, out_features: {out_features}")
|
||||
new_module = bnb.nn.Linear8bitLt(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=False,
|
||||
)
|
||||
|
||||
# 复制权重和偏置
|
||||
|
||||
# Copy weights and bias
|
||||
new_module.weight.data = module.weight.data
|
||||
if bias:
|
||||
new_module.bias.data = module.bias.data
|
||||
|
||||
# 替换模块
|
||||
|
||||
# Replace module
|
||||
setattr(model, name, new_module)
|
||||
else:
|
||||
# 递归处理子模块
|
||||
# Process child modules recursively
|
||||
replace_linear_with_linear8bitlt(module)
|
||||
|
||||
|
||||
return model
|
||||
|
||||
# 替换所有线性层
|
||||
|
||||
# Replace all linear layers
|
||||
model = replace_linear_with_linear8bitlt(model)
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
# 将模型移到GPU(量化发生在这里)
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
|
||||
# Move model to GPU (quantization happens here)
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
model = model.to(device)
|
||||
|
||||
|
||||
print("- All linear layers replaced with Linear8bitLt")
|
||||
|
||||
|
||||
else:
|
||||
# 使用原来的Int4量化方法
|
||||
# Use original Int4 quantization method
|
||||
print("- Using bitsandbytes for Int4 quantization")
|
||||
|
||||
|
||||
# Create quantization config
|
||||
|
||||
|
||||
compute_dtype = torch.float16 if self.config.use_fp16 else torch.float32
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
|
||||
print("- Quantization config:", quantization_config)
|
||||
|
||||
|
||||
# Load model directly with quantization config
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto" # Let HF decide on device mapping
|
||||
device_map="auto", # Let HF decide on device mapping
|
||||
)
|
||||
|
||||
|
||||
# Check if model loaded successfully
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply optimizations directly here
|
||||
print("\nApplying model optimizations:")
|
||||
|
||||
if hasattr(self.config, 'use_linear8bitlt') and self.config.use_linear8bitlt:
|
||||
|
||||
if hasattr(self.config, "use_linear8bitlt") and self.config.use_linear8bitlt:
|
||||
print("- Model moved to GPU with Linear8bitLt quantization")
|
||||
else:
|
||||
# Skip moving to GPU since device_map="auto" already did that
|
||||
print("- Model already on GPU due to device_map='auto'")
|
||||
|
||||
|
||||
# Skip FP16 conversion since we specified compute_dtype
|
||||
print(f"- Using {compute_dtype} for compute dtype")
|
||||
|
||||
|
||||
# Check CUDA and SDPA
|
||||
if torch.cuda.is_available() and torch.version.cuda and float(torch.version.cuda[:3]) >= 11.6:
|
||||
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and torch.version.cuda
|
||||
and float(torch.version.cuda[:3]) >= 11.6
|
||||
):
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
print("- Using PyTorch SDPA (scaled_dot_product_attention)")
|
||||
else:
|
||||
print("- PyTorch SDPA not available")
|
||||
|
||||
|
||||
# Try xformers if available (only on CUDA)
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention
|
||||
if hasattr(model, 'enable_xformers_memory_efficient_attention'):
|
||||
if hasattr(model, "enable_xformers_memory_efficient_attention"):
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
print("- Enabled xformers memory efficient attention")
|
||||
else:
|
||||
print("- Model doesn't support xformers")
|
||||
except (ImportError, AttributeError):
|
||||
print("- Xformers not available")
|
||||
|
||||
|
||||
# Set to eval mode
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
@@ -365,76 +378,83 @@ class Benchmark:
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
)
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
self.config.model_path,
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=compute_dtype,
|
||||
device_map="auto"
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
|
||||
if model is None:
|
||||
raise ValueError("Model loading returned None")
|
||||
|
||||
|
||||
print(f"- Model type: {type(model)}")
|
||||
model.eval()
|
||||
print("- Model set to eval mode")
|
||||
|
||||
|
||||
else:
|
||||
# Standard loading for FP16/FP32
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
print("- Model loaded in standard precision")
|
||||
print(f"- Model type: {type(model)}")
|
||||
|
||||
|
||||
# Apply standard optimizations
|
||||
# set default to half
|
||||
import torch
|
||||
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
model = ModelOptimizer.optimize(model, self.config)
|
||||
model = model.half()
|
||||
# add torch compile
|
||||
model = torch.compile(model)
|
||||
|
||||
|
||||
# Final check to ensure model is not None
|
||||
if model is None:
|
||||
raise ValueError("Model is None after optimization")
|
||||
|
||||
|
||||
print(f"- Final model type: {type(model)}")
|
||||
return model
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"ERROR loading model: {str(e)}")
|
||||
print(f"ERROR loading model: {e!s}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
graph_wrapper: Optional[GraphWrapper] = None
|
||||
) -> Tuple[float, torch.Tensor]:
|
||||
self, input_ids: torch.Tensor, graph_wrapper: GraphWrapper | None = None
|
||||
) -> tuple[float, torch.Tensor]:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
|
||||
with torch.no_grad(), self.timer.timing():
|
||||
if graph_wrapper is not None:
|
||||
output = graph_wrapper(input_ids, attention_mask)
|
||||
else:
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
return self.timer.elapsed_time(), output
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
# Reset peak memory stats
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
@@ -443,22 +463,20 @@ class Benchmark:
|
||||
pass
|
||||
else:
|
||||
print("- No GPU memory stats available")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
# Get or create graph for this batch size
|
||||
graph_wrapper = (
|
||||
self.graphs.get_or_create(batch_size)
|
||||
if self.graphs is not None
|
||||
else None
|
||||
self.graphs.get_or_create(batch_size) if self.graphs is not None else None
|
||||
)
|
||||
|
||||
|
||||
# Pre-allocate input tensor
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
@@ -469,44 +487,44 @@ class Benchmark:
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"No successful runs for batch size {batch_size}, skipping")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
# Log memory usage
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
elif torch.backends.mps.is_available():
|
||||
# MPS doesn't have max_memory_allocated, use 0
|
||||
peak_memory_gb = 0.0
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
print("- No GPU memory usage available")
|
||||
|
||||
|
||||
if peak_memory_gb > 0:
|
||||
print(f"\nPeak GPU memory usage: {peak_memory_gb:.2f} GB")
|
||||
else:
|
||||
print("\n- GPU memory usage not available")
|
||||
|
||||
|
||||
# Add memory info to results
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -566,14 +584,14 @@ def main():
|
||||
action="store_true",
|
||||
help="Enable Linear8bitLt quantization for all linear layers",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# Print arguments for debugging
|
||||
print("\nCommand line arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f"- {arg}: {value}")
|
||||
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path=args.model_path,
|
||||
batch_sizes=[int(bs) for bs in args.batch_sizes.split(",")],
|
||||
@@ -586,45 +604,56 @@ def main():
|
||||
use_flash_attention=args.use_flash_attention,
|
||||
use_linear8bitlt=args.use_linear8bitlt,
|
||||
)
|
||||
|
||||
|
||||
# Print configuration for debugging
|
||||
print("\nBenchmark configuration:")
|
||||
for field, value in vars(config).items():
|
||||
print(f"- {field}: {value}")
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
# Save results to file
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
# Create results directory if it doesn't exist
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
|
||||
# Generate filename based on configuration
|
||||
precision_type = "int4" if config.use_int4 else "int8" if config.use_int8 else "fp16" if config.use_fp16 else "fp32"
|
||||
precision_type = (
|
||||
"int4"
|
||||
if config.use_int4
|
||||
else "int8"
|
||||
if config.use_int8
|
||||
else "fp16"
|
||||
if config.use_fp16
|
||||
else "fp32"
|
||||
)
|
||||
model_name = os.path.basename(config.model_path)
|
||||
output_file = f"results/benchmark_{model_name}_{precision_type}.json"
|
||||
|
||||
|
||||
# Save results
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"config": {k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()},
|
||||
"results": {str(k): v for k, v in results.items()}
|
||||
},
|
||||
f,
|
||||
indent=2
|
||||
"config": {
|
||||
k: str(v) if isinstance(v, list) else v for k, v in vars(config).items()
|
||||
},
|
||||
"results": {str(k): v for k, v in results.items()},
|
||||
},
|
||||
f,
|
||||
indent=2,
|
||||
)
|
||||
print(f"Results saved to {output_file}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
@@ -5,24 +5,21 @@ It correctly compares results by fetching the text content for both the new sear
|
||||
results and the golden standard results, making the comparison robust to ID changes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
from leann.api import LeannSearcher, LeannBuilder
|
||||
import numpy as np
|
||||
from leann.api import LeannBuilder, LeannChat, LeannSearcher
|
||||
|
||||
|
||||
def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
"""Checks if the data directory exists, and if not, downloads it from HF Hub."""
|
||||
if not data_root.exists():
|
||||
print(f"Data directory '{data_root}' not found.")
|
||||
print(
|
||||
"Downloading evaluation data from Hugging Face Hub... (this may take a moment)"
|
||||
)
|
||||
print("Downloading evaluation data from Hugging Face Hub... (this may take a moment)")
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
@@ -56,14 +53,14 @@ def download_data_if_needed(data_root: Path, download_embeddings: bool = False):
|
||||
print(
|
||||
"Error: huggingface_hub is not installed. Please install it to download the data:"
|
||||
)
|
||||
print("uv pip install -e '.[dev]'")
|
||||
print("uv sync --only-group dev")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during data download: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
def download_embeddings_if_needed(data_root: Path, dataset_type: str | None = None):
|
||||
"""Download embeddings files specifically."""
|
||||
embeddings_dir = data_root / "embeddings"
|
||||
|
||||
@@ -101,7 +98,7 @@ def download_embeddings_if_needed(data_root: Path, dataset_type: str = None):
|
||||
|
||||
|
||||
# --- Helper Function to get Golden Passages ---
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
def get_golden_texts(searcher: LeannSearcher, golden_ids: list[int]) -> set:
|
||||
"""
|
||||
Retrieves the text for golden passage IDs directly from the LeannSearcher's
|
||||
passage manager.
|
||||
@@ -113,24 +110,20 @@ def get_golden_texts(searcher: LeannSearcher, golden_ids: List[int]) -> set:
|
||||
passage_data = searcher.passage_manager.get_passage(str(gid))
|
||||
golden_texts.add(passage_data["text"])
|
||||
except KeyError:
|
||||
print(
|
||||
f"Warning: Golden passage ID '{gid}' not found in the index's passage data."
|
||||
)
|
||||
print(f"Warning: Golden passage ID '{gid}' not found in the index's passage data.")
|
||||
return golden_texts
|
||||
|
||||
|
||||
def load_queries(file_path: Path) -> List[str]:
|
||||
def load_queries(file_path: Path) -> list[str]:
|
||||
queries = []
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
data = json.loads(line)
|
||||
queries.append(data["query"])
|
||||
return queries
|
||||
|
||||
|
||||
def build_index_from_embeddings(
|
||||
embeddings_file: str, output_path: str, backend: str = "hnsw"
|
||||
):
|
||||
def build_index_from_embeddings(embeddings_file: str, output_path: str, backend: str = "hnsw"):
|
||||
"""
|
||||
Build a LEANN index from pre-computed embeddings.
|
||||
|
||||
@@ -173,9 +166,7 @@ def build_index_from_embeddings(
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run recall evaluation on a LEANN index."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Run recall evaluation on a LEANN index.")
|
||||
parser.add_argument(
|
||||
"index_path",
|
||||
type=str,
|
||||
@@ -202,26 +193,41 @@ def main():
|
||||
parser.add_argument(
|
||||
"--num-queries", type=int, default=10, help="Number of queries to evaluate."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k", type=int, default=3, help="The 'k' value for recall@k."
|
||||
)
|
||||
parser.add_argument("--top-k", type=int, default=3, help="The 'k' value for recall@k.")
|
||||
parser.add_argument(
|
||||
"--ef-search", type=int, default=120, help="The 'efSearch' parameter for HNSW."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Batch size for HNSW batched search (0 disables batching)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-type",
|
||||
type=str,
|
||||
choices=["ollama", "hf", "openai", "gemini", "simulated"],
|
||||
default="ollama",
|
||||
help="LLM backend type to optionally query during evaluation (default: ollama)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model",
|
||||
type=str,
|
||||
default="qwen3:1.7b",
|
||||
help="LLM model identifier for the chosen backend (default: qwen3:1.7b)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# --- Path Configuration ---
|
||||
# Assumes a project structure where the script is in 'examples/'
|
||||
# and data is in 'data/' at the project root.
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
data_root = project_root / "data"
|
||||
# Assumes a project structure where the script is in 'benchmarks/'
|
||||
# and evaluation data is in 'benchmarks/data/'.
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
data_root = script_dir / "data"
|
||||
|
||||
# Download data based on mode
|
||||
if args.mode == "build":
|
||||
# For building mode, we need embeddings
|
||||
download_data_if_needed(
|
||||
data_root, download_embeddings=False
|
||||
) # Basic data first
|
||||
download_data_if_needed(data_root, download_embeddings=False) # Basic data first
|
||||
|
||||
# Auto-detect dataset type and download embeddings
|
||||
if args.embeddings_file:
|
||||
@@ -262,9 +268,7 @@ def main():
|
||||
print(f"Index built successfully: {built_index_path}")
|
||||
|
||||
# Ask if user wants to run evaluation
|
||||
eval_response = (
|
||||
input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
)
|
||||
eval_response = input("Run evaluation on the built index? (y/n): ").strip().lower()
|
||||
if eval_response != "y":
|
||||
print("Index building complete. Exiting.")
|
||||
return
|
||||
@@ -293,11 +297,9 @@ def main():
|
||||
break
|
||||
|
||||
if not args.index_path:
|
||||
print("No indices found. The data download should have included pre-built indices.")
|
||||
print(
|
||||
"No indices found. The data download should have included pre-built indices."
|
||||
)
|
||||
print(
|
||||
"Please check the data/indices/ directory or provide --index-path manually."
|
||||
"Please check the benchmarks/data/indices/ directory or provide --index-path manually."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
@@ -310,14 +312,10 @@ def main():
|
||||
else:
|
||||
# Fallback: try to infer from the index directory name
|
||||
dataset_type = Path(args.index_path).name
|
||||
print(
|
||||
f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'."
|
||||
)
|
||||
print(f"WARNING: Could not detect dataset type from path, inferred '{dataset_type}'.")
|
||||
|
||||
queries_file = data_root / "queries" / "nq_open.jsonl"
|
||||
golden_results_file = (
|
||||
data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
)
|
||||
golden_results_file = data_root / "ground_truth" / dataset_type / "flat_results_nq_k3.json"
|
||||
|
||||
print(f"INFO: Detected dataset type: {dataset_type}")
|
||||
print(f"INFO: Using queries file: {queries_file}")
|
||||
@@ -327,7 +325,7 @@ def main():
|
||||
searcher = LeannSearcher(args.index_path)
|
||||
queries = load_queries(queries_file)
|
||||
|
||||
with open(golden_results_file, "r") as f:
|
||||
with open(golden_results_file) as f:
|
||||
golden_results_data = json.load(f)
|
||||
|
||||
num_eval_queries = min(args.num_queries, len(queries))
|
||||
@@ -340,10 +338,23 @@ def main():
|
||||
for i in range(num_eval_queries):
|
||||
start_time = time.time()
|
||||
new_results = searcher.search(
|
||||
queries[i], top_k=args.top_k, ef=args.ef_search
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
search_times.append(time.time() - start_time)
|
||||
|
||||
# Optional: also call the LLM with configurable backend/model (does not affect recall)
|
||||
llm_config = {"type": args.llm_type, "model": args.llm_model}
|
||||
chat = LeannChat(args.index_path, llm_config=llm_config, searcher=searcher)
|
||||
answer = chat.ask(
|
||||
queries[i],
|
||||
top_k=args.top_k,
|
||||
complexity=args.ef_search,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
print(f"Answer: {answer}")
|
||||
# Correct Recall Calculation: Based on TEXT content
|
||||
new_texts = {result.text for result in new_results}
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BitsAndBytesConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModel
|
||||
|
||||
# Add MLX imports
|
||||
try:
|
||||
import mlx.core as mx
|
||||
from mlx_lm.utils import load
|
||||
|
||||
MLX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
print("MLX not available. Install with: uv pip install mlx mlx-lm")
|
||||
MLX_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
model_path: str = "facebook/contriever"
|
||||
batch_sizes: List[int] = None
|
||||
model_path: str = "facebook/contriever-msmarco"
|
||||
batch_sizes: list[int] = None
|
||||
seq_length: int = 256
|
||||
num_runs: int = 5
|
||||
use_fp16: bool = True
|
||||
@@ -30,18 +31,19 @@ class BenchmarkConfig:
|
||||
use_flash_attention: bool = False
|
||||
use_linear8bitlt: bool = False
|
||||
use_mlx: bool = False # New flag for MLX testing
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
if self.batch_sizes is None:
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64]
|
||||
self.batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
|
||||
|
||||
|
||||
class MLXBenchmark:
|
||||
"""MLX-specific benchmark for embedding models"""
|
||||
|
||||
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.model, self.tokenizer = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self):
|
||||
"""Load MLX model and tokenizer following the API pattern"""
|
||||
print(f"Loading MLX model from {self.config.model_path}...")
|
||||
@@ -52,55 +54,51 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Error loading MLX model: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int):
|
||||
"""Create random input batches for MLX testing - same as PyTorch"""
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
return torch.randint(0, 1000, (batch_size, self.config.seq_length), dtype=torch.long)
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
"""Run MLX inference with same input as PyTorch"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Convert PyTorch tensor to MLX array
|
||||
input_ids_mlx = mx.array(input_ids.numpy())
|
||||
|
||||
|
||||
# Get embeddings
|
||||
embeddings = self.model(input_ids_mlx)
|
||||
|
||||
|
||||
# Mean pooling (following the API pattern)
|
||||
pooled = embeddings.mean(axis=1)
|
||||
|
||||
|
||||
# Convert to numpy (following the API pattern)
|
||||
pooled_numpy = np.array(pooled.tolist(), dtype=np.float32)
|
||||
|
||||
|
||||
# Force computation
|
||||
_ = pooled_numpy.shape
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX inference error: {e}")
|
||||
return float('inf')
|
||||
return float("inf")
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
"""Run the MLX benchmark across all batch sizes"""
|
||||
results = {}
|
||||
|
||||
|
||||
print(f"Starting MLX benchmark with model: {self.config.model_path}")
|
||||
print(f"Testing batch sizes: {self.config.batch_sizes}")
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\n=== Testing MLX batch size: {batch_size} ===")
|
||||
times = []
|
||||
|
||||
|
||||
# Create input batch (same as PyTorch)
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
|
||||
# Warm up
|
||||
print("Warming up...")
|
||||
for _ in range(3):
|
||||
@@ -109,26 +107,26 @@ class MLXBenchmark:
|
||||
except Exception as e:
|
||||
print(f"Warmup error: {e}")
|
||||
break
|
||||
|
||||
|
||||
# Run benchmark
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"MLX Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
if elapsed_time != float('inf'):
|
||||
if elapsed_time != float("inf"):
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during MLX inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
print(f"Skipping batch size {batch_size} due to errors")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate statistics
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
@@ -136,122 +134,133 @@ class MLXBenchmark:
|
||||
"min_time": np.min(times),
|
||||
"max_time": np.max(times),
|
||||
}
|
||||
|
||||
|
||||
print(f"MLX Results for batch size {batch_size}:")
|
||||
print(f" Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f" Min Time: {np.min(times):.4f}s")
|
||||
print(f" Max Time: {np.max(times):.4f}s")
|
||||
print(f" Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self, config: BenchmarkConfig):
|
||||
self.config = config
|
||||
self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
self.device = (
|
||||
"cuda"
|
||||
if torch.cuda.is_available()
|
||||
else "mps"
|
||||
if torch.backends.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
self.model = self._load_model()
|
||||
|
||||
|
||||
def _load_model(self) -> nn.Module:
|
||||
print(f"Loading model from {self.config.model_path}...")
|
||||
|
||||
|
||||
|
||||
model = AutoModel.from_pretrained(self.config.model_path)
|
||||
if self.config.use_fp16:
|
||||
model = model.half()
|
||||
model = torch.compile(model)
|
||||
model = model.to(self.device)
|
||||
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def _create_random_batch(self, batch_size: int) -> torch.Tensor:
|
||||
return torch.randint(
|
||||
0, 1000,
|
||||
0,
|
||||
1000,
|
||||
(batch_size, self.config.seq_length),
|
||||
device=self.device,
|
||||
dtype=torch.long
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
|
||||
def _run_inference(self, input_ids: torch.Tensor) -> float:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# print shape of input_ids and attention_mask
|
||||
print(f"input_ids shape: {input_ids.shape}")
|
||||
print(f"attention_mask shape: {attention_mask.shape}")
|
||||
start_time = time.time()
|
||||
with torch.no_grad():
|
||||
output = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
return end_time - start_time
|
||||
|
||||
def run(self) -> Dict[int, Dict[str, float]]:
|
||||
|
||||
def run(self) -> dict[int, dict[str, float]]:
|
||||
results = {}
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
for batch_size in self.config.batch_sizes:
|
||||
print(f"\nTesting batch size: {batch_size}")
|
||||
times = []
|
||||
|
||||
|
||||
input_ids = self._create_random_batch(batch_size)
|
||||
|
||||
for i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
|
||||
for _i in tqdm(range(self.config.num_runs), desc=f"Batch size {batch_size}"):
|
||||
try:
|
||||
elapsed_time = self._run_inference(input_ids)
|
||||
times.append(elapsed_time)
|
||||
except Exception as e:
|
||||
print(f"Error during inference: {e}")
|
||||
break
|
||||
|
||||
|
||||
if not times:
|
||||
continue
|
||||
|
||||
|
||||
avg_time = np.mean(times)
|
||||
std_time = np.std(times)
|
||||
throughput = batch_size / avg_time
|
||||
|
||||
|
||||
results[batch_size] = {
|
||||
"avg_time": avg_time,
|
||||
"std_time": std_time,
|
||||
"throughput": throughput,
|
||||
}
|
||||
|
||||
|
||||
print(f"Avg Time: {avg_time:.4f}s ± {std_time:.4f}s")
|
||||
print(f"Throughput: {throughput:.2f} sequences/second")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
|
||||
peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
else:
|
||||
peak_memory_gb = 0.0
|
||||
|
||||
|
||||
for batch_size in results:
|
||||
results[batch_size]["peak_memory_gb"] = peak_memory_gb
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def run_benchmark():
|
||||
"""Main function to run the benchmark with optimized parameters."""
|
||||
config = BenchmarkConfig()
|
||||
|
||||
|
||||
try:
|
||||
benchmark = Benchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
def run_mlx_benchmark():
|
||||
"""Run MLX-specific benchmark"""
|
||||
@@ -260,55 +269,49 @@ def run_mlx_benchmark():
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "MLX not available"
|
||||
"error": "MLX not available",
|
||||
}
|
||||
|
||||
config = BenchmarkConfig(
|
||||
model_path="mlx-community/all-MiniLM-L6-v2-4bit",
|
||||
use_mlx=True
|
||||
)
|
||||
|
||||
|
||||
config = BenchmarkConfig(model_path="mlx-community/all-MiniLM-L6-v2-4bit", use_mlx=True)
|
||||
|
||||
try:
|
||||
benchmark = MLXBenchmark(config)
|
||||
results = benchmark.run()
|
||||
|
||||
|
||||
if not results:
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": "No valid results"
|
||||
"error": "No valid results",
|
||||
}
|
||||
|
||||
|
||||
max_throughput = max(results[batch_size]["throughput"] for batch_size in results)
|
||||
avg_throughput = np.mean([results[batch_size]["throughput"] for batch_size in results])
|
||||
|
||||
|
||||
return {
|
||||
"max_throughput": max_throughput,
|
||||
"avg_throughput": avg_throughput,
|
||||
"results": results
|
||||
"results": results,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"MLX benchmark failed: {e}")
|
||||
return {
|
||||
"max_throughput": 0.0,
|
||||
"avg_throughput": 0.0,
|
||||
"error": str(e)
|
||||
}
|
||||
return {"max_throughput": 0.0, "avg_throughput": 0.0, "error": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== PyTorch Benchmark ===")
|
||||
pytorch_result = run_benchmark()
|
||||
print(f"PyTorch Max throughput: {pytorch_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"PyTorch Average throughput: {pytorch_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
print("\n=== MLX Benchmark ===")
|
||||
mlx_result = run_mlx_benchmark()
|
||||
print(f"MLX Max throughput: {mlx_result['max_throughput']:.2f} sequences/second")
|
||||
print(f"MLX Average throughput: {mlx_result['avg_throughput']:.2f} sequences/second")
|
||||
|
||||
|
||||
# Compare results
|
||||
if pytorch_result['max_throughput'] > 0 and mlx_result['max_throughput'] > 0:
|
||||
speedup = mlx_result['max_throughput'] / pytorch_result['max_throughput']
|
||||
print(f"\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
if pytorch_result["max_throughput"] > 0 and mlx_result["max_throughput"] > 0:
|
||||
speedup = mlx_result["max_throughput"] / pytorch_result["max_throughput"]
|
||||
print("\n=== Comparison ===")
|
||||
print(f"MLX is {speedup:.2f}x {'faster' if speedup > 1 else 'slower'} than PyTorch")
|
||||
143
benchmarks/update/README.md
Normal file
143
benchmarks/update/README.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Update Benchmarks
|
||||
|
||||
This directory hosts two benchmark suites that exercise LEANN’s HNSW “update +
|
||||
search” pipeline under different assumptions:
|
||||
|
||||
1. **RNG recompute latency** – measure how random-neighbour pruning and cache
|
||||
settings influence incremental `add()` latency when embeddings are fetched
|
||||
over the ZMQ embedding server.
|
||||
2. **Update strategy comparison** – compare a fully sequential update pipeline
|
||||
against an offline approach that keeps the graph static and fuses results.
|
||||
|
||||
Both suites build a non-compact, `is_recompute=True` index so that new
|
||||
embeddings are pulled from the embedding server. Benchmark outputs are written
|
||||
under `.leann/bench/` by default and appended to CSV files for later plotting.
|
||||
|
||||
## Benchmarks
|
||||
|
||||
### 1. HNSW RNG Recompute Benchmark
|
||||
|
||||
`bench_hnsw_rng_recompute.py` evaluates incremental update latency under four
|
||||
random-neighbour (RNG) configurations. Each scenario uses the same dataset but
|
||||
changes the forward / reverse RNG pruning flags and whether the embedding cache
|
||||
is enabled:
|
||||
|
||||
| Scenario name | Forward RNG | Reverse RNG | ZMQ embedding cache |
|
||||
| ---------------------------------- | ----------- | ----------- | ------------------- |
|
||||
| `baseline` | Enabled | Enabled | Enabled |
|
||||
| `no_cache_baseline` | Enabled | Enabled | **Disabled** |
|
||||
| `disable_forward_rng` | **Disabled**| Enabled | Enabled |
|
||||
| `disable_forward_and_reverse_rng` | **Disabled**| **Disabled**| Enabled |
|
||||
|
||||
For each scenario the script:
|
||||
1. (Re)builds a `is_recompute=True` index and writes it to `.leann/bench/`.
|
||||
2. Starts `leann_backend_hnsw.hnsw_embedding_server` for remote embeddings.
|
||||
3. Appends the requested updates using the scenario’s RNG flags.
|
||||
4. Records total time, latency per passage, ZMQ fetch counts, and stage-level
|
||||
timings before appending a row to the CSV output.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
LEANN_HNSW_LOG_PATH=.leann/bench/hnsw_server.log \
|
||||
LEANN_LOG_LEVEL=INFO \
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--runs 1 \
|
||||
--index-path .leann/bench/test.leann \
|
||||
--initial-files data/PrideandPrejudice.txt \
|
||||
--update-files data/huawei_pangu.md \
|
||||
--max-initial 300 \
|
||||
--max-updates 1 \
|
||||
--add-timeout 120
|
||||
```
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_results.csv` – per-scenario timing statistics
|
||||
(including ms/passage) for each run.
|
||||
- `.leann/bench/hnsw_server.log` – detailed ZMQ/server logs (path controlled by
|
||||
`LEANN_HNSW_LOG_PATH`).
|
||||
_The reference CSVs checked into this branch were generated on a workstation with an NVIDIA RTX 4090 GPU; throughput numbers will differ on other hardware._
|
||||
|
||||
### 2. Sequential vs. Offline Update Benchmark
|
||||
|
||||
`bench_update_vs_offline_search.py` compares two end-to-end strategies on the
|
||||
same dataset:
|
||||
|
||||
- **Scenario A – Sequential Update**
|
||||
- Start an embedding server.
|
||||
- Sequentially call `index.add()`; each call fetches embeddings via ZMQ and
|
||||
mutates the HNSW graph.
|
||||
- After all inserts, run a search on the updated graph.
|
||||
- Metrics recorded: update time (`add_total_s`), post-update search time
|
||||
(`search_time_s`), combined total (`total_time_s`), and per-passage
|
||||
latency.
|
||||
|
||||
- **Scenario B – Offline Embedding + Concurrent Search**
|
||||
- Stop Scenario A’s server and start a fresh embedding server.
|
||||
- Spawn two threads: one generates embeddings for the new passages offline
|
||||
(graph unchanged); the other computes the query embedding and searches the
|
||||
existing graph.
|
||||
- Merge offline similarities with the graph search results to emulate late
|
||||
fusion, then report the merged top‑k preview.
|
||||
- Metrics recorded: embedding time (`emb_time_s`), search time
|
||||
(`search_time_s`), concurrent makespan (`makespan_s`), and scenario total.
|
||||
|
||||
**Run (both scenarios):**
|
||||
```bash
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 \
|
||||
--num-updates 1
|
||||
```
|
||||
|
||||
You can pass `--only A` or `--only B` to run a single scenario. The script will
|
||||
print timing summaries to stdout and append the results to CSV.
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/offline_vs_update.csv` – per-scenario timing statistics for
|
||||
Scenario A and B.
|
||||
- Console output includes Scenario B’s merged top‑k preview for quick sanity
|
||||
checks.
|
||||
_The sample results committed here come from runs on an RTX 4090-equipped machine; expect variations if you benchmark on different GPUs._
|
||||
|
||||
### 3. Visualisation
|
||||
|
||||
`plot_bench_results.py` combines the RNG benchmark and the update strategy
|
||||
benchmark into a single two-panel plot.
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
uv run -m benchmarks.update.plot_bench_results \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--csv-right benchmarks/update/offline_vs_update.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--broken-y` – Enable a broken Y-axis (default: true when appropriate).
|
||||
- `--csv` – RNG benchmark results CSV (left panel).
|
||||
- `--csv-right` – Update strategy results CSV (right panel).
|
||||
- `--out` – Output image path (PNG/PDF supported).
|
||||
|
||||
**Output:**
|
||||
- `benchmarks/update/bench_latency_from_csv.png` – visual comparison of the two
|
||||
suites.
|
||||
- `benchmarks/update/bench_latency_from_csv.pdf` – PDF version, suitable for
|
||||
slides/papers.
|
||||
|
||||
## Parameters & Environment
|
||||
|
||||
### Common CLI Flags
|
||||
- `--max-initial` – Number of initial passages used to seed the index.
|
||||
- `--max-updates` / `--num-updates` – Number of passages to treat as updates.
|
||||
- `--index-path` – Base path (without extension) where the LEANN index is stored.
|
||||
- `--runs` – Number of repetitions (RNG benchmark only).
|
||||
|
||||
### Environment Variables
|
||||
- `LEANN_HNSW_LOG_PATH` – File to receive embedding-server logs (optional).
|
||||
- `LEANN_LOG_LEVEL` – Logging verbosity (DEBUG/INFO/WARNING/ERROR).
|
||||
- `CUDA_VISIBLE_DEVICES` – Set to empty string if you want to force CPU
|
||||
execution of the embedding model.
|
||||
|
||||
With these scripts you can easily replicate LEANN’s update benchmarks, compare
|
||||
multiple RNG strategies, and evaluate whether sequential updates or offline
|
||||
fusion better match your latency/accuracy trade-offs.
|
||||
16
benchmarks/update/__init__.py
Normal file
16
benchmarks/update/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Benchmarks for LEANN update workflows."""
|
||||
|
||||
# Expose helper to locate repository root for other modules that need it.
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def find_repo_root() -> Path:
|
||||
"""Return the project root containing pyproject.toml."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
return current.parents[1]
|
||||
|
||||
|
||||
__all__ = ["find_repo_root"]
|
||||
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
804
benchmarks/update/bench_hnsw_rng_recompute.py
Normal file
@@ -0,0 +1,804 @@
|
||||
"""Benchmark incremental HNSW add() under different RNG pruning modes with real
|
||||
embedding recomputation.
|
||||
|
||||
This script clones the structure of ``examples/dynamic_update_no_recompute.py``
|
||||
so that we build a non-compact ``is_recompute=True`` index, spin up the
|
||||
standard HNSW embedding server, and measure how long incremental ``add`` takes
|
||||
when RNG pruning is fully enabled vs. partially/fully disabled.
|
||||
|
||||
Example usage (run from the repo root; downloads the model on first run)::
|
||||
|
||||
uv run -m benchmarks.update.bench_hnsw_rng_recompute \
|
||||
--index-path .leann/bench/leann-demo.leann \
|
||||
--runs 1
|
||||
|
||||
You can tweak the input documents with ``--initial-files`` / ``--update-files``
|
||||
if you want a larger or different workload, and change the embedding model via
|
||||
``--model-name``.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
import zmq
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
from leann_backend_hnsw.convert_to_csr import prune_hnsw_embeddings_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
DEFAULT_HNSW_LOG = Path(".leann/bench/hnsw_server.log")
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def prepare_new_chunks(paragraphs: list[str]) -> list[dict[str, Any]]:
|
||||
return [{"text": text, "metadata": {}} for text in paragraphs]
|
||||
|
||||
|
||||
def benchmark_update_with_mode(
|
||||
index_path: Path,
|
||||
new_chunks: list[dict[str, Any]],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
disable_forward_rng: bool,
|
||||
disable_reverse_rng: bool,
|
||||
server_port: int,
|
||||
add_timeout: int,
|
||||
ef_construction: int,
|
||||
) -> tuple[float, float]:
|
||||
meta_path = index_path.parent / f"{index_path.name}.meta.json"
|
||||
passages_file = index_path.parent / f"{index_path.name}.passages.jsonl"
|
||||
offset_file = index_path.parent / f"{index_path.name}.passages.idx"
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
with open(offset_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
existing_ids = set(offset_map.keys())
|
||||
|
||||
valid_chunks: list[dict[str, Any]] = []
|
||||
for chunk in new_chunks:
|
||||
text = chunk.get("text", "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
continue
|
||||
metadata = chunk.setdefault("metadata", {})
|
||||
passage_id = chunk.get("id") or metadata.get("id")
|
||||
if passage_id and passage_id in existing_ids:
|
||||
raise ValueError(f"Passage ID '{passage_id}' already exists in the index.")
|
||||
valid_chunks.append(chunk)
|
||||
|
||||
if not valid_chunks:
|
||||
raise ValueError("No valid chunks to append.")
|
||||
|
||||
texts_to_embed = [chunk["text"] for chunk in valid_chunks]
|
||||
embeddings = compute_embeddings(
|
||||
texts_to_embed,
|
||||
model_name,
|
||||
mode=embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
|
||||
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
|
||||
if distance_metric == "cosine":
|
||||
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
embeddings = embeddings / norms
|
||||
|
||||
index = faiss.read_index(str(index_file))
|
||||
index.is_recompute = True
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
try:
|
||||
storage_index.ntotal = index.ntotal
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(disable_forward_rng)
|
||||
index.hnsw.set_disable_reverse_prune(disable_reverse_rng)
|
||||
if ef_construction is not None:
|
||||
index.hnsw.efConstruction = ef_construction
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
applied_forward = getattr(index.hnsw, "disable_rng_during_add", None)
|
||||
applied_reverse = getattr(index.hnsw, "disable_reverse_prune", None)
|
||||
logger.info(
|
||||
"HNSW RNG config -> requested forward=%s, reverse=%s | applied forward=%s, reverse=%s",
|
||||
disable_forward_rng,
|
||||
disable_reverse_rng,
|
||||
applied_forward,
|
||||
applied_reverse,
|
||||
)
|
||||
|
||||
base_id = index.ntotal
|
||||
for offset, chunk in enumerate(valid_chunks):
|
||||
new_id = str(base_id + offset)
|
||||
chunk.setdefault("metadata", {})["id"] = new_id
|
||||
chunk["id"] = new_id
|
||||
|
||||
rollback_size = passages_file.stat().st_size if passages_file.exists() else 0
|
||||
offset_map_backup = offset_map.copy()
|
||||
|
||||
try:
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for chunk in valid_chunks:
|
||||
offset = f.tell()
|
||||
json.dump(
|
||||
{
|
||||
"id": chunk["id"],
|
||||
"text": chunk["text"],
|
||||
"metadata": chunk.get("metadata", {}),
|
||||
},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
f.write("\n")
|
||||
offset_map[chunk["id"]] = offset
|
||||
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
server_started, actual_port = server_manager.start_server(
|
||||
port=server_port,
|
||||
model_name=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=distance_metric,
|
||||
)
|
||||
if not server_started:
|
||||
raise RuntimeError("Failed to start embedding server.")
|
||||
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(actual_port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(actual_port)
|
||||
|
||||
_warmup_embedding_server(actual_port)
|
||||
|
||||
total_start = time.time()
|
||||
add_elapsed = 0.0
|
||||
|
||||
try:
|
||||
import signal
|
||||
|
||||
def _timeout_handler(signum, frame):
|
||||
raise TimeoutError("incremental add timed out")
|
||||
|
||||
if add_timeout > 0:
|
||||
signal.signal(signal.SIGALRM, _timeout_handler)
|
||||
signal.alarm(add_timeout)
|
||||
|
||||
add_start = time.time()
|
||||
for i in range(embeddings.shape[0]):
|
||||
index.add(1, faiss.swig_ptr(embeddings[i : i + 1]))
|
||||
add_elapsed = time.time() - add_start
|
||||
if add_timeout > 0:
|
||||
signal.alarm(0)
|
||||
faiss.write_index(index, str(index_file))
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
|
||||
except TimeoutError:
|
||||
raise
|
||||
except Exception:
|
||||
if passages_file.exists():
|
||||
with open(passages_file, "rb+") as f:
|
||||
f.truncate(rollback_size)
|
||||
with open(offset_file, "wb") as f:
|
||||
pickle.dump(offset_map_backup, f)
|
||||
raise
|
||||
|
||||
prune_hnsw_embeddings_inplace(str(index_file))
|
||||
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
# Reset toggles so the index on disk returns to baseline behaviour.
|
||||
try:
|
||||
index.hnsw.set_disable_rng_during_add(False)
|
||||
index.hnsw.set_disable_reverse_prune(False)
|
||||
except AttributeError:
|
||||
pass
|
||||
faiss.write_index(index, str(index_file))
|
||||
|
||||
total_elapsed = time.time() - total_start
|
||||
|
||||
return total_elapsed, add_elapsed
|
||||
|
||||
|
||||
def _total_zmq_nodes(log_path: Path) -> int:
|
||||
if not log_path.exists():
|
||||
return 0
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
text = log_file.read()
|
||||
return sum(int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", text))
|
||||
|
||||
|
||||
def _warmup_embedding_server(port: int) -> None:
|
||||
"""Send a dummy REQ so the embedding server loads its model."""
|
||||
ctx = zmq.Context()
|
||||
try:
|
||||
sock = ctx.socket(zmq.REQ)
|
||||
sock.setsockopt(zmq.LINGER, 0)
|
||||
sock.setsockopt(zmq.RCVTIMEO, 5000)
|
||||
sock.setsockopt(zmq.SNDTIMEO, 5000)
|
||||
sock.connect(f"tcp://127.0.0.1:{port}")
|
||||
payload = msgpack.packb(["__WARMUP__"], use_bin_type=True)
|
||||
sock.send(payload)
|
||||
try:
|
||||
sock.recv()
|
||||
except zmq.error.Again:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/leann-demo.leann"),
|
||||
help="Output index base path (without extension).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
help="Files used to build the initial index.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
help="Files appended during the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--runs", type=int, default=1, help="How many times to repeat each scenario."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
help="Embedding model used for build/update.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--embedding-mode",
|
||||
default="sentence-transformers",
|
||||
help="Embedding mode passed to LeannBuilder/embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
help="Distance metric for HNSW backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ef-construction",
|
||||
type=int,
|
||||
default=200,
|
||||
help="efConstruction setting for initial build.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server-port",
|
||||
type=int,
|
||||
default=5557,
|
||||
help="Port for the real embedding server.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-initial",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Optional cap on initial passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-updates",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Optional cap on update passages (after chunking).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add-timeout",
|
||||
type=int,
|
||||
default=900,
|
||||
help="Timeout in seconds for the incremental add loop (0 = no timeout).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-path",
|
||||
type=Path,
|
||||
default=Path("bench_latency.png"),
|
||||
help="Where to save the latency bar plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis (ms). Bars above are hatched and annotated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
help="Use broken Y-axis (two stacked axes with gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default=1.1x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default=1.2x second-highest.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Where to append per-scenario results as CSV.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, args.max_updates)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages found; please provide --update-files with content.")
|
||||
|
||||
update_chunks = prepare_new_chunks(update_paragraphs)
|
||||
ensure_index_dir(args.index_path)
|
||||
|
||||
scenarios = [
|
||||
("baseline", False, False, True),
|
||||
("no_cache_baseline", False, False, False),
|
||||
("disable_forward_rng", True, False, True),
|
||||
("disable_forward_and_reverse_rng", True, True, True),
|
||||
]
|
||||
|
||||
log_path = Path(os.environ.get("LEANN_HNSW_LOG_PATH", DEFAULT_HNSW_LOG))
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["LEANN_HNSW_LOG_PATH"] = str(log_path.resolve())
|
||||
os.environ.setdefault("LEANN_LOG_LEVEL", "INFO")
|
||||
|
||||
results_total: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_add: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_zmq: dict[str, list[int]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageA: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_stageBC: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
results_ms_per_passage: dict[str, list[float]] = {name: [] for name, *_ in scenarios}
|
||||
|
||||
# CSV setup
|
||||
import csv
|
||||
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"cache_enabled",
|
||||
"ef_construction",
|
||||
"max_initial",
|
||||
"max_updates",
|
||||
"total_time_s",
|
||||
"add_only_s",
|
||||
"latency_ms_per_passage",
|
||||
"zmq_nodes",
|
||||
"stageA_time_s",
|
||||
"stageBC_time_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
# Create CSV with header if missing
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
for run in range(args.runs):
|
||||
print(f"\n=== Benchmark run {run + 1}/{args.runs} ===")
|
||||
for name, disable_forward, disable_reverse, cache_enabled in scenarios:
|
||||
print(f"\nScenario: {name}")
|
||||
cleanup_index_files(args.index_path)
|
||||
if log_path.exists():
|
||||
try:
|
||||
log_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
os.environ["LEANN_ZMQ_EMBED_CACHE"] = "1" if cache_enabled else "0"
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
prev_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
|
||||
try:
|
||||
total_elapsed, add_elapsed = benchmark_update_with_mode(
|
||||
args.index_path,
|
||||
update_chunks,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
disable_forward,
|
||||
disable_reverse,
|
||||
args.server_port,
|
||||
args.add_timeout,
|
||||
args.ef_construction,
|
||||
)
|
||||
except TimeoutError as exc:
|
||||
print(f"Scenario {name} timed out: {exc}")
|
||||
continue
|
||||
|
||||
curr_size = log_path.stat().st_size if log_path.exists() else 0
|
||||
if curr_size < prev_size:
|
||||
prev_size = 0
|
||||
zmq_count = 0
|
||||
if log_path.exists():
|
||||
with log_path.open("r", encoding="utf-8") as log_file:
|
||||
log_file.seek(prev_size)
|
||||
new_entries = log_file.read()
|
||||
zmq_count = sum(
|
||||
int(match) for match in re.findall(r"ZMQ received (\d+) node IDs", new_entries)
|
||||
)
|
||||
stageA = sum(
|
||||
float(x)
|
||||
for x in re.findall(r"Distance calculation E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
stageBC = sum(
|
||||
float(x) for x in re.findall(r"ZMQ E2E time: ([0-9.]+)s", new_entries)
|
||||
)
|
||||
else:
|
||||
stageA = 0.0
|
||||
stageBC = 0.0
|
||||
|
||||
per_chunk = add_elapsed / len(update_chunks)
|
||||
print(
|
||||
f"Total time: {total_elapsed:.3f} s | add-only: {add_elapsed:.3f} s "
|
||||
f"for {len(update_chunks)} passages => {per_chunk * 1e3:.3f} ms/passage"
|
||||
)
|
||||
print(f"ZMQ node fetch total: {zmq_count}")
|
||||
results_total[name].append(total_elapsed)
|
||||
results_add[name].append(add_elapsed)
|
||||
results_zmq[name].append(zmq_count)
|
||||
results_ms_per_passage[name].append(per_chunk * 1e3)
|
||||
results_stageA[name].append(stageA)
|
||||
results_stageBC[name].append(stageBC)
|
||||
|
||||
# Append row to CSV
|
||||
if args.csv_path:
|
||||
row = {
|
||||
"run_id": run_id,
|
||||
"scenario": name,
|
||||
"cache_enabled": 1 if cache_enabled else 0,
|
||||
"ef_construction": args.ef_construction,
|
||||
"max_initial": args.max_initial,
|
||||
"max_updates": args.max_updates,
|
||||
"total_time_s": round(total_elapsed, 6),
|
||||
"add_only_s": round(add_elapsed, 6),
|
||||
"latency_ms_per_passage": round(per_chunk * 1e3, 6),
|
||||
"zmq_nodes": int(zmq_count),
|
||||
"stageA_time_s": round(stageA, 6),
|
||||
"stageBC_time_s": round(stageBC, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row)
|
||||
|
||||
print("\n=== Summary ===")
|
||||
for name in results_add:
|
||||
add_values = results_add[name]
|
||||
total_values = results_total[name]
|
||||
zmq_values = results_zmq[name]
|
||||
latency_values = results_ms_per_passage[name]
|
||||
if not add_values:
|
||||
print(f"{name}: no successful runs")
|
||||
continue
|
||||
avg_add = sum(add_values) / len(add_values)
|
||||
avg_total = sum(total_values) / len(total_values)
|
||||
avg_zmq = sum(zmq_values) / len(zmq_values) if zmq_values else 0.0
|
||||
avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0.0
|
||||
runs = len(add_values)
|
||||
print(
|
||||
f"{name}: add-only avg {avg_add:.3f} s | total avg {avg_total:.3f} s "
|
||||
f"| ZMQ avg {avg_zmq:.1f} node fetches | latency {avg_latency:.2f} ms/passage over {runs} run(s)"
|
||||
)
|
||||
|
||||
if args.plot_path:
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
labels = [name for name, *_ in scenarios]
|
||||
values = [
|
||||
sum(results_ms_per_passage[name]) / len(results_ms_per_passage[name])
|
||||
if results_ms_per_passage[name]
|
||||
else 0.0
|
||||
for name in labels
|
||||
]
|
||||
|
||||
def _auto_cap(vals: list[float]) -> float | None:
|
||||
s = sorted(vals, reverse=True)
|
||||
if len(s) < 2:
|
||||
return None
|
||||
if s[1] > 0 and s[0] >= 2.5 * s[1]:
|
||||
return s[1] * 1.1
|
||||
return None
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.1f}"
|
||||
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
if args.broken_y:
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.4, 5.0),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.05},
|
||||
)
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i,
|
||||
v + lower_cap * 0.02,
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False)
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
d = 0.015
|
||||
kwargs = {"transform": ax_top.transAxes, "color": "k", "clip_on": False}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
ax_bottom.set_xticks(range(len(labels)))
|
||||
ax_bottom.set_xticklabels(labels)
|
||||
ax = ax_bottom
|
||||
else:
|
||||
cap = args.cap_y or _auto_cap(values)
|
||||
plt.figure(figsize=(7.2, 4.2))
|
||||
ax = plt.gca()
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (v, show) in enumerate(zip(values, show_vals)):
|
||||
b = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(b[0])
|
||||
if v > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(v),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
ax.plot(
|
||||
[0.02 - 0.02, 0.02 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
ax.plot(
|
||||
[0.98 - 0.02, 0.98 + 0.02],
|
||||
[0.98 + 0.02, 0.98 - 0.02],
|
||||
transform=ax.transAxes,
|
||||
color="k",
|
||||
lw=1,
|
||||
)
|
||||
if any(v > cap for v in values):
|
||||
ax.legend(
|
||||
[bars[0]], ["capped"], fontsize=8, frameon=False, loc="upper right"
|
||||
)
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels)
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(idx, val + 1.0, f"{val:.1f}", ha="center", va="bottom")
|
||||
|
||||
plt.ylabel("Average add latency (ms per passage)")
|
||||
plt.title(f"Initial passages {args.max_initial}, updates {args.max_updates}")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.plot_path)
|
||||
print(f"Saved latency bar plot to {args.plot_path}")
|
||||
# ZMQ time split (Stage A vs B/C)
|
||||
try:
|
||||
plt.figure(figsize=(6, 4))
|
||||
a_vals = [sum(results_stageA[n]) / max(1, len(results_stageA[n])) for n in labels]
|
||||
bc_vals = [
|
||||
sum(results_stageBC[n]) / max(1, len(results_stageBC[n])) for n in labels
|
||||
]
|
||||
ind = range(len(labels))
|
||||
plt.bar(ind, a_vals, color="#4e79a7", label="Stage A distance (s)")
|
||||
plt.bar(
|
||||
ind, bc_vals, bottom=a_vals, color="#e15759", label="Stage B/C embed-by-id (s)"
|
||||
)
|
||||
plt.xticks(list(ind), labels, rotation=10)
|
||||
plt.ylabel("Server ZMQ time (s)")
|
||||
plt.title(
|
||||
f"ZMQ time split (initial {args.max_initial}, updates {args.max_updates})"
|
||||
)
|
||||
plt.legend()
|
||||
out2 = args.plot_path.with_name(
|
||||
args.plot_path.stem + "_zmq_split" + args.plot_path.suffix
|
||||
)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out2)
|
||||
print(f"Saved ZMQ time split plot to {out2}")
|
||||
except Exception as e:
|
||||
print("Failed to plot ZMQ split:", e)
|
||||
except ImportError:
|
||||
print("matplotlib not available; skipping plot generation")
|
||||
|
||||
# leave the last build on disk for inspection
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/bench_results.csv
Normal file
5
benchmarks/update/bench_results.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,cache_enabled,ef_construction,max_initial,max_updates,total_time_s,add_only_s,latency_ms_per_passage,zmq_nodes,stageA_time_s,stageBC_time_s,model_name,embedding_mode,distance_metric
|
||||
20251024-133101,baseline,1,200,300,1,3.391856,1.120359,1120.359421,126,0.507821,0.601608,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,no_cache_baseline,0,200,300,1,34.941514,32.91376,32913.760185,4033,0.506933,32.159928,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_rng,1,200,300,1,2.746756,0.8202,820.200443,66,0.474354,0.338454,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-133101,disable_forward_and_reverse_rng,1,200,300,1,2.396566,0.521478,521.478415,1,0.508973,0.006938,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
704
benchmarks/update/bench_update_vs_offline_search.py
Normal file
@@ -0,0 +1,704 @@
|
||||
"""
|
||||
Compare two latency models for small incremental updates vs. search:
|
||||
|
||||
Scenario A (sequential update then search):
|
||||
- Build initial HNSW (is_recompute=True)
|
||||
- Start embedding server (ZMQ) for recompute
|
||||
- Add N passages one-by-one (each triggers recompute over ZMQ)
|
||||
- Then run a search query on the updated index
|
||||
- Report total time = sum(add_i) + search_time, with breakdowns
|
||||
|
||||
Scenario B (offline embeds + concurrent search; no graph updates):
|
||||
- Do NOT insert the N passages into the graph
|
||||
- In parallel: (1) compute embeddings for the N passages; (2) compute query
|
||||
embedding and run a search on the existing index
|
||||
- After both finish, compute similarity between the query embedding and the N
|
||||
new passage embeddings, merge with the index search results by score, and
|
||||
report time = max(embed_time, search_time) (i.e., no blocking on updates)
|
||||
|
||||
This script reuses the model/data loading conventions of
|
||||
examples/bench_hnsw_rng_recompute.py but focuses on end-to-end latency
|
||||
comparison for the two execution strategies above.
|
||||
|
||||
Example (from the repository root):
|
||||
uv run -m benchmarks.update.bench_update_vs_offline_search \
|
||||
--index-path .leann/bench/offline_vs_update.leann \
|
||||
--max-initial 300 --num-updates 5 --k 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import psutil # type: ignore
|
||||
from leann.api import LeannBuilder
|
||||
|
||||
if os.environ.get("LEANN_FORCE_CPU", "").lower() in ("1", "true", "yes"):
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
|
||||
|
||||
from leann.embedding_compute import compute_embeddings
|
||||
from leann.embedding_server_manager import EmbeddingServerManager
|
||||
from leann.registry import register_project_directory
|
||||
from leann_backend_hnsw import faiss # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
if not logging.getLogger().handlers:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def _find_repo_root() -> Path:
|
||||
"""Locate project root by walking up until pyproject.toml is found."""
|
||||
current = Path(__file__).resolve()
|
||||
for parent in current.parents:
|
||||
if (parent / "pyproject.toml").exists():
|
||||
return parent
|
||||
# Fallback: assume repo is two levels up (../..)
|
||||
return current.parents[2]
|
||||
|
||||
|
||||
REPO_ROOT = _find_repo_root()
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from apps.chunking import create_text_chunks # noqa: E402
|
||||
|
||||
DEFAULT_INITIAL_FILES = [
|
||||
REPO_ROOT / "data" / "2501.14312v1 (1).pdf",
|
||||
REPO_ROOT / "data" / "huawei_pangu.md",
|
||||
]
|
||||
DEFAULT_UPDATE_FILES = [REPO_ROOT / "data" / "2506.08276v1.pdf"]
|
||||
|
||||
|
||||
def load_chunks_from_files(paths: list[Path], limit: int | None = None) -> list[str]:
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
documents = []
|
||||
for path in paths:
|
||||
p = path.expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Input path not found: {p}")
|
||||
if p.is_dir():
|
||||
reader = SimpleDirectoryReader(str(p), recursive=False)
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=[str(p)])
|
||||
documents.extend(reader.load_data(show_progress=True))
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
chunks = create_text_chunks(
|
||||
documents,
|
||||
chunk_size=512,
|
||||
chunk_overlap=128,
|
||||
use_ast_chunking=False,
|
||||
)
|
||||
cleaned = [c for c in chunks if isinstance(c, str) and c.strip()]
|
||||
if limit is not None:
|
||||
cleaned = cleaned[:limit]
|
||||
return cleaned
|
||||
|
||||
|
||||
def ensure_index_dir(index_path: Path) -> None:
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def cleanup_index_files(index_path: Path) -> None:
|
||||
parent = index_path.parent
|
||||
if not parent.exists():
|
||||
return
|
||||
stem = index_path.stem
|
||||
for file in parent.glob(f"{stem}*"):
|
||||
if file.is_file():
|
||||
file.unlink()
|
||||
|
||||
|
||||
def build_initial_index(
|
||||
index_path: Path,
|
||||
paragraphs: list[str],
|
||||
model_name: str,
|
||||
embedding_mode: str,
|
||||
distance_metric: str,
|
||||
ef_construction: int,
|
||||
) -> None:
|
||||
builder = LeannBuilder(
|
||||
backend_name="hnsw",
|
||||
embedding_model=model_name,
|
||||
embedding_mode=embedding_mode,
|
||||
is_compact=False,
|
||||
is_recompute=True,
|
||||
distance_metric=distance_metric,
|
||||
backend_kwargs={
|
||||
"distance_metric": distance_metric,
|
||||
"is_compact": False,
|
||||
"is_recompute": True,
|
||||
"efConstruction": ef_construction,
|
||||
},
|
||||
)
|
||||
for idx, passage in enumerate(paragraphs):
|
||||
builder.add_text(passage, metadata={"id": str(idx)})
|
||||
builder.build_index(str(index_path))
|
||||
|
||||
|
||||
def _maybe_norm_cosine(vecs: np.ndarray, metric: str) -> np.ndarray:
|
||||
if metric == "cosine":
|
||||
vecs = np.ascontiguousarray(vecs, dtype=np.float32)
|
||||
norms = np.linalg.norm(vecs, axis=1, keepdims=True)
|
||||
norms[norms == 0] = 1
|
||||
vecs = vecs / norms
|
||||
return vecs
|
||||
|
||||
|
||||
def _read_index_for_search(index_path: Path) -> Any:
|
||||
index_file = index_path.parent / f"{index_path.stem}.index"
|
||||
# Force-disable experimental disk cache when loading the index so that
|
||||
# incremental benchmarks don't pick up stale top-degree bitmaps.
|
||||
cfg = faiss.HNSWIndexConfig()
|
||||
cfg.is_recompute = True
|
||||
if hasattr(cfg, "disk_cache_ratio"):
|
||||
cfg.disk_cache_ratio = 0.0
|
||||
if hasattr(cfg, "external_storage_path"):
|
||||
cfg.external_storage_path = None
|
||||
io_flags = getattr(faiss, "IO_FLAG_MMAP", 0)
|
||||
index = faiss.read_index(str(index_file), io_flags, cfg)
|
||||
# ensure recompute mode persists after reload
|
||||
try:
|
||||
index.is_recompute = True
|
||||
except AttributeError:
|
||||
pass
|
||||
try:
|
||||
actual_ntotal = index.hnsw.levels.size()
|
||||
except AttributeError:
|
||||
actual_ntotal = index.ntotal
|
||||
if actual_ntotal != index.ntotal:
|
||||
print(
|
||||
f"[bench_update_vs_offline_search] Correcting ntotal from {index.ntotal} to {actual_ntotal}",
|
||||
flush=True,
|
||||
)
|
||||
index.ntotal = actual_ntotal
|
||||
if getattr(index, "storage", None) is None:
|
||||
if index.metric_type == faiss.METRIC_INNER_PRODUCT:
|
||||
storage_index = faiss.IndexFlatIP(index.d)
|
||||
else:
|
||||
storage_index = faiss.IndexFlatL2(index.d)
|
||||
index.storage = storage_index
|
||||
index.own_fields = True
|
||||
return index
|
||||
|
||||
|
||||
def _append_passages_for_updates(
|
||||
meta_path: Path,
|
||||
start_id: int,
|
||||
texts: list[str],
|
||||
) -> list[str]:
|
||||
"""Append update passages so the embedding server can serve recompute fetches."""
|
||||
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
index_dir = meta_path.parent
|
||||
meta_name = meta_path.name
|
||||
if not meta_name.endswith(".meta.json"):
|
||||
raise ValueError(f"Unexpected meta filename: {meta_path}")
|
||||
index_base = meta_name[: -len(".meta.json")]
|
||||
|
||||
passages_file = index_dir / f"{index_base}.passages.jsonl"
|
||||
offsets_file = index_dir / f"{index_base}.passages.idx"
|
||||
|
||||
if not passages_file.exists() or not offsets_file.exists():
|
||||
raise FileNotFoundError(
|
||||
"Passage store missing; cannot register update passages for recompute mode."
|
||||
)
|
||||
|
||||
with open(offsets_file, "rb") as f:
|
||||
offset_map: dict[str, int] = pickle.load(f)
|
||||
|
||||
assigned_ids: list[str] = []
|
||||
with open(passages_file, "a", encoding="utf-8") as f:
|
||||
for i, text in enumerate(texts):
|
||||
passage_id = str(start_id + i)
|
||||
offset = f.tell()
|
||||
json.dump({"id": passage_id, "text": text, "metadata": {}}, f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
offset_map[passage_id] = offset
|
||||
assigned_ids.append(passage_id)
|
||||
|
||||
with open(offsets_file, "wb") as f:
|
||||
pickle.dump(offset_map, f)
|
||||
|
||||
try:
|
||||
with open(meta_path, encoding="utf-8") as f:
|
||||
meta = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
meta = {}
|
||||
meta["total_passages"] = len(offset_map)
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
return assigned_ids
|
||||
|
||||
|
||||
def _search(index: Any, q: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
||||
q = np.ascontiguousarray(q, dtype=np.float32)
|
||||
distances = np.zeros((1, k), dtype=np.float32)
|
||||
indices = np.zeros((1, k), dtype=np.int64)
|
||||
index.search(
|
||||
1,
|
||||
faiss.swig_ptr(q),
|
||||
k,
|
||||
faiss.swig_ptr(distances),
|
||||
faiss.swig_ptr(indices),
|
||||
)
|
||||
return distances[0], indices[0]
|
||||
|
||||
|
||||
def _score_for_metric(dist: float, metric: str) -> float:
|
||||
# Convert FAISS distance to a "higher is better" score
|
||||
if metric in ("mips", "cosine"):
|
||||
return float(dist)
|
||||
# l2 distance (smaller better) -> negative distance as score
|
||||
return -float(dist)
|
||||
|
||||
|
||||
def _merge_results(
|
||||
index_results: tuple[np.ndarray, np.ndarray],
|
||||
offline_scores: list[tuple[int, float]],
|
||||
k: int,
|
||||
metric: str,
|
||||
) -> list[tuple[str, float]]:
|
||||
distances, indices = index_results
|
||||
merged: list[tuple[str, float]] = []
|
||||
for distance, idx in zip(distances.tolist(), indices.tolist()):
|
||||
merged.append((f"idx:{idx}", _score_for_metric(distance, metric)))
|
||||
for j, s in offline_scores:
|
||||
merged.append((f"offline:{j}", s))
|
||||
merged.sort(key=lambda x: x[1], reverse=True)
|
||||
return merged[:k]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenarioResult:
|
||||
name: str
|
||||
update_total_s: float
|
||||
search_s: float
|
||||
overall_s: float
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--index-path",
|
||||
type=Path,
|
||||
default=Path(".leann/bench/offline-vs-update.leann"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--initial-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_INITIAL_FILES,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-files",
|
||||
nargs="*",
|
||||
type=Path,
|
||||
default=DEFAULT_UPDATE_FILES,
|
||||
)
|
||||
parser.add_argument("--max-initial", type=int, default=300)
|
||||
parser.add_argument("--num-updates", type=int, default=5)
|
||||
parser.add_argument("--k", type=int, default=10, help="Top-k for search/merge")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
type=str,
|
||||
default="neural network",
|
||||
help="Query text used for the search benchmark.",
|
||||
)
|
||||
parser.add_argument("--server-port", type=int, default=5557)
|
||||
parser.add_argument("--add-timeout", type=int, default=600)
|
||||
parser.add_argument("--model-name", default="sentence-transformers/all-MiniLM-L6-v2")
|
||||
parser.add_argument("--embedding-mode", default="sentence-transformers")
|
||||
parser.add_argument(
|
||||
"--distance-metric",
|
||||
default="mips",
|
||||
choices=["mips", "l2", "cosine"],
|
||||
)
|
||||
parser.add_argument("--ef-construction", type=int, default=200)
|
||||
parser.add_argument(
|
||||
"--only",
|
||||
choices=["A", "B", "both"],
|
||||
default="both",
|
||||
help="Run only Scenario A, Scenario B, or both",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--csv-path",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Where to append results (CSV).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
register_project_directory(REPO_ROOT)
|
||||
|
||||
# Load data
|
||||
initial_paragraphs = load_chunks_from_files(args.initial_files, args.max_initial)
|
||||
update_paragraphs = load_chunks_from_files(args.update_files, None)
|
||||
if not update_paragraphs:
|
||||
raise ValueError("No update passages loaded from --update-files")
|
||||
update_paragraphs = update_paragraphs[: args.num_updates]
|
||||
if len(update_paragraphs) < args.num_updates:
|
||||
raise ValueError(
|
||||
f"Not enough update passages ({len(update_paragraphs)}) for --num-updates={args.num_updates}"
|
||||
)
|
||||
|
||||
ensure_index_dir(args.index_path)
|
||||
cleanup_index_files(args.index_path)
|
||||
|
||||
# Build initial index
|
||||
build_initial_index(
|
||||
args.index_path,
|
||||
initial_paragraphs,
|
||||
args.model_name,
|
||||
args.embedding_mode,
|
||||
args.distance_metric,
|
||||
args.ef_construction,
|
||||
)
|
||||
|
||||
# Prepare index object and meta
|
||||
meta_path = args.index_path.parent / f"{args.index_path.name}.meta.json"
|
||||
index = _read_index_for_search(args.index_path)
|
||||
|
||||
# CSV setup
|
||||
run_id = time.strftime("%Y%m%d-%H%M%S")
|
||||
if args.csv_path:
|
||||
args.csv_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
csv_fields = [
|
||||
"run_id",
|
||||
"scenario",
|
||||
"max_initial",
|
||||
"num_updates",
|
||||
"k",
|
||||
"total_time_s",
|
||||
"add_total_s",
|
||||
"search_time_s",
|
||||
"emb_time_s",
|
||||
"makespan_s",
|
||||
"model_name",
|
||||
"embedding_mode",
|
||||
"distance_metric",
|
||||
]
|
||||
if not args.csv_path.exists() or args.csv_path.stat().st_size == 0:
|
||||
with args.csv_path.open("w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writeheader()
|
||||
|
||||
# Debug: list existing HNSW server PIDs before starting
|
||||
try:
|
||||
existing = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if existing:
|
||||
print("[debug] Found existing hnsw_embedding_server processes before run:")
|
||||
for p in existing:
|
||||
print(f"[debug] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}")
|
||||
except Exception as _e:
|
||||
pass
|
||||
|
||||
add_total = 0.0
|
||||
search_after_add = 0.0
|
||||
total_seq = 0.0
|
||||
port_a = None
|
||||
if args.only in ("A", "both"):
|
||||
# Scenario A: sequential update then search
|
||||
start_id = index.ntotal
|
||||
assigned_ids = _append_passages_for_updates(meta_path, start_id, update_paragraphs)
|
||||
if assigned_ids:
|
||||
logger.debug(
|
||||
"Registered %d update passages starting at id %s",
|
||||
len(assigned_ids),
|
||||
assigned_ids[0],
|
||||
)
|
||||
server_manager = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
ok, port = server_manager.start_server(
|
||||
port=args.server_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError("Failed to start embedding server")
|
||||
try:
|
||||
# Set ZMQ port for recompute mode
|
||||
if hasattr(index.hnsw, "set_zmq_port"):
|
||||
index.hnsw.set_zmq_port(port)
|
||||
elif hasattr(index, "set_zmq_port"):
|
||||
index.set_zmq_port(port)
|
||||
|
||||
# Start A overall timer BEFORE computing update embeddings
|
||||
t0 = time.time()
|
||||
|
||||
# Compute embeddings for updates (counted into A's overall)
|
||||
t_emb0 = time.time()
|
||||
upd_embs = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time_updates = time.time() - t_emb0
|
||||
upd_embs = np.asarray(upd_embs, dtype=np.float32)
|
||||
upd_embs = _maybe_norm_cosine(upd_embs, args.distance_metric)
|
||||
|
||||
# Perform sequential adds
|
||||
for i in range(upd_embs.shape[0]):
|
||||
t_add0 = time.time()
|
||||
index.add(1, faiss.swig_ptr(upd_embs[i : i + 1]))
|
||||
add_total += time.time() - t_add0
|
||||
# Don't persist index after adds to avoid contaminating Scenario B
|
||||
# index_file = args.index_path.parent / f"{args.index_path.stem}.index"
|
||||
# faiss.write_index(index, str(index_file))
|
||||
|
||||
# Search after updates
|
||||
q_emb = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_emb = np.asarray(q_emb, dtype=np.float32)
|
||||
q_emb = _maybe_norm_cosine(q_emb, args.distance_metric)
|
||||
|
||||
# Warm up search with a dummy query first
|
||||
print("[DEBUG] Warming up search...")
|
||||
_ = _search(index, q_emb, 1)
|
||||
|
||||
t_s0 = time.time()
|
||||
D_upd, I_upd = _search(index, q_emb, args.k)
|
||||
search_after_add = time.time() - t_s0
|
||||
total_seq = time.time() - t0
|
||||
finally:
|
||||
server_manager.stop_server()
|
||||
port_a = port
|
||||
|
||||
print("\n=== Scenario A: update->search (sequential) ===")
|
||||
# emb_time_updates is defined only when A runs
|
||||
try:
|
||||
_emb_a = emb_time_updates
|
||||
except NameError:
|
||||
_emb_a = 0.0
|
||||
print(
|
||||
f"Adds: {args.num_updates} passages; embeds={_emb_a:.3f}s; add_total={add_total:.3f}s; "
|
||||
f"search={search_after_add:.3f}s; overall={total_seq:.3f}s"
|
||||
)
|
||||
# CSV row for A
|
||||
if args.csv_path:
|
||||
row_a = {
|
||||
"run_id": run_id,
|
||||
"scenario": "A",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": round(total_seq, 6),
|
||||
"add_total_s": round(add_total, 6),
|
||||
"search_time_s": round(search_after_add, 6),
|
||||
"emb_time_s": round(_emb_a, 6),
|
||||
"makespan_s": 0.0,
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_a)
|
||||
|
||||
# Verify server cleanup
|
||||
try:
|
||||
# short sleep to allow signal handling to finish
|
||||
time.sleep(0.5)
|
||||
leftovers = [
|
||||
p
|
||||
for p in psutil.process_iter(attrs=["pid", "cmdline"])
|
||||
if any(
|
||||
isinstance(arg, str) and "leann_backend_hnsw.hnsw_embedding_server" in arg
|
||||
for arg in (p.info.get("cmdline") or [])
|
||||
)
|
||||
]
|
||||
if leftovers:
|
||||
print("[warn] hnsw_embedding_server process(es) still alive after A-stop:")
|
||||
for p in leftovers:
|
||||
print(
|
||||
f"[warn] PID={p.info['pid']} cmd={' '.join(p.info.get('cmdline') or [])}"
|
||||
)
|
||||
else:
|
||||
print("[debug] server cleanup confirmed: no hnsw_embedding_server found")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Scenario B: offline embeds + concurrent search (no graph updates)
|
||||
if args.only in ("B", "both"):
|
||||
# ensure a server is available for recompute search
|
||||
server_manager_b = EmbeddingServerManager(
|
||||
backend_module_name="leann_backend_hnsw.hnsw_embedding_server"
|
||||
)
|
||||
requested_port = args.server_port if port_a is None else port_a
|
||||
ok_b, port_b = server_manager_b.start_server(
|
||||
port=requested_port,
|
||||
model_name=args.model_name,
|
||||
embedding_mode=args.embedding_mode,
|
||||
passages_file=str(meta_path),
|
||||
distance_metric=args.distance_metric,
|
||||
)
|
||||
if not ok_b:
|
||||
raise RuntimeError("Failed to start embedding server for Scenario B")
|
||||
|
||||
# Wait for server to fully initialize
|
||||
print("[DEBUG] Waiting 2s for embedding server to fully initialize...")
|
||||
time.sleep(2)
|
||||
|
||||
try:
|
||||
# Read the index first
|
||||
index_no_update = _read_index_for_search(args.index_path) # unchanged index
|
||||
|
||||
# Then configure ZMQ port on the correct index object
|
||||
if hasattr(index_no_update.hnsw, "set_zmq_port"):
|
||||
index_no_update.hnsw.set_zmq_port(port_b)
|
||||
elif hasattr(index_no_update, "set_zmq_port"):
|
||||
index_no_update.set_zmq_port(port_b)
|
||||
|
||||
# Warmup the embedding model before benchmarking (do this for both --only B and --only both)
|
||||
# This ensures fair comparison as Scenario A has warmed up the model during update embeddings
|
||||
logger.info("Warming up embedding model for Scenario B...")
|
||||
_ = compute_embeddings(
|
||||
["warmup text"], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
|
||||
# Prepare worker A: compute embeddings for the same N passages
|
||||
emb_time = 0.0
|
||||
updates_embs_offline: np.ndarray | None = None
|
||||
|
||||
def _worker_emb():
|
||||
nonlocal emb_time, updates_embs_offline
|
||||
t = time.time()
|
||||
updates_embs_offline = compute_embeddings(
|
||||
update_paragraphs,
|
||||
args.model_name,
|
||||
mode=args.embedding_mode,
|
||||
is_build=False,
|
||||
batch_size=16,
|
||||
)
|
||||
emb_time = time.time() - t
|
||||
|
||||
# Pre-compute query embedding and warm up search outside of timed section.
|
||||
q_vec = compute_embeddings(
|
||||
[args.query], args.model_name, mode=args.embedding_mode, is_build=False
|
||||
)
|
||||
q_vec = np.asarray(q_vec, dtype=np.float32)
|
||||
q_vec = _maybe_norm_cosine(q_vec, args.distance_metric)
|
||||
print("[DEBUG B] Warming up search...")
|
||||
_ = _search(index_no_update, q_vec, 1)
|
||||
|
||||
# Worker B: timed search on the warmed index
|
||||
search_time = 0.0
|
||||
offline_elapsed = 0.0
|
||||
index_results: tuple[np.ndarray, np.ndarray] | None = None
|
||||
|
||||
def _worker_search():
|
||||
nonlocal search_time, index_results
|
||||
t = time.time()
|
||||
distances, indices = _search(index_no_update, q_vec, args.k)
|
||||
search_time = time.time() - t
|
||||
index_results = (distances, indices)
|
||||
|
||||
# Run two workers concurrently
|
||||
t0 = time.time()
|
||||
th1 = threading.Thread(target=_worker_emb)
|
||||
th2 = threading.Thread(target=_worker_search)
|
||||
th1.start()
|
||||
th2.start()
|
||||
th1.join()
|
||||
th2.join()
|
||||
offline_elapsed = time.time() - t0
|
||||
|
||||
# For mixing: compute query vs. offline update similarities (pure client-side)
|
||||
offline_scores: list[tuple[int, float]] = []
|
||||
if updates_embs_offline is not None:
|
||||
upd2 = np.asarray(updates_embs_offline, dtype=np.float32)
|
||||
upd2 = _maybe_norm_cosine(upd2, args.distance_metric)
|
||||
# For mips/cosine, score = dot; for l2, score = -||x-y||^2
|
||||
for j in range(upd2.shape[0]):
|
||||
if args.distance_metric in ("mips", "cosine"):
|
||||
s = float(np.dot(q_vec[0], upd2[j]))
|
||||
else:
|
||||
diff = q_vec[0] - upd2[j]
|
||||
s = -float(np.dot(diff, diff))
|
||||
offline_scores.append((j, s))
|
||||
|
||||
merged_topk = (
|
||||
_merge_results(index_results, offline_scores, args.k, args.distance_metric)
|
||||
if index_results
|
||||
else []
|
||||
)
|
||||
|
||||
print("\n=== Scenario B: offline embeds + concurrent search (no add) ===")
|
||||
print(
|
||||
f"embeddings({args.num_updates})={emb_time:.3f}s; search={search_time:.3f}s; makespan≈{offline_elapsed:.3f}s (≈max)"
|
||||
)
|
||||
if merged_topk:
|
||||
preview = ", ".join([f"{lab}:{score:.3f}" for lab, score in merged_topk[:5]])
|
||||
print(f"Merged top-5 preview: {preview}")
|
||||
# CSV row for B
|
||||
if args.csv_path:
|
||||
row_b = {
|
||||
"run_id": run_id,
|
||||
"scenario": "B",
|
||||
"max_initial": args.max_initial,
|
||||
"num_updates": args.num_updates,
|
||||
"k": args.k,
|
||||
"total_time_s": 0.0,
|
||||
"add_total_s": 0.0,
|
||||
"search_time_s": round(search_time, 6),
|
||||
"emb_time_s": round(emb_time, 6),
|
||||
"makespan_s": round(offline_elapsed, 6),
|
||||
"model_name": args.model_name,
|
||||
"embedding_mode": args.embedding_mode,
|
||||
"distance_metric": args.distance_metric,
|
||||
}
|
||||
with args.csv_path.open("a", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
||||
writer.writerow(row_b)
|
||||
|
||||
finally:
|
||||
server_manager_b.stop_server()
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
msg_a = (
|
||||
f"A: seq-add+search overall={total_seq:.3f}s (adds={add_total:.3f}s, search={search_after_add:.3f}s)"
|
||||
if args.only in ("A", "both")
|
||||
else "A: skipped"
|
||||
)
|
||||
msg_b = (
|
||||
f"B: offline+concurrent overall≈{offline_elapsed:.3f}s (emb={emb_time:.3f}s, search={search_time:.3f}s)"
|
||||
if args.only in ("B", "both")
|
||||
else "B: skipped"
|
||||
)
|
||||
print(msg_a + "\n" + msg_b)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
benchmarks/update/offline_vs_update.csv
Normal file
5
benchmarks/update/offline_vs_update.csv
Normal file
@@ -0,0 +1,5 @@
|
||||
run_id,scenario,max_initial,num_updates,k,total_time_s,add_total_s,search_time_s,emb_time_s,makespan_s,model_name,embedding_mode,distance_metric
|
||||
20251024-141607,A,300,1,10,3.273957,3.050168,0.097825,0.017339,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251024-141607,B,300,1,10,0.0,0.0,0.111892,0.007869,0.112635,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,A,300,5,10,5.061945,4.805962,0.123271,0.015008,0.0,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
20251025-160652,B,300,5,10,0.0,0.0,0.101809,0.008817,0.102447,sentence-transformers/all-MiniLM-L6-v2,sentence-transformers,mips
|
||||
|
645
benchmarks/update/plot_bench_results.py
Normal file
645
benchmarks/update/plot_bench_results.py
Normal file
@@ -0,0 +1,645 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Plot latency bars from the benchmark CSV produced by
|
||||
benchmarks/update/bench_hnsw_rng_recompute.py.
|
||||
|
||||
If you also provide an offline_vs_update.csv via --csv-right
|
||||
(from benchmarks/update/bench_update_vs_offline_search.py), this script will
|
||||
output a side-by-side figure:
|
||||
- Left: ms/passage bars (four RNG scenarios).
|
||||
- Right: seconds bars (Scenario A seq add+search vs Scenario B offline+search).
|
||||
|
||||
Usage:
|
||||
uv run python benchmarks/update/plot_bench_results.py \
|
||||
--csv benchmarks/update/bench_results.csv \
|
||||
--out benchmarks/update/bench_latency_from_csv.png
|
||||
|
||||
The script selects the latest run_id in the CSV and plots four bars for
|
||||
the default scenarios:
|
||||
- baseline
|
||||
- no_cache_baseline
|
||||
- disable_forward_rng
|
||||
- disable_forward_and_reverse_rng
|
||||
|
||||
If multiple rows exist per scenario for that run_id, the script averages
|
||||
their latency_ms_per_passage values.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_SCENARIOS = [
|
||||
"no_cache_baseline",
|
||||
"baseline",
|
||||
"disable_forward_rng",
|
||||
"disable_forward_and_reverse_rng",
|
||||
]
|
||||
|
||||
SCENARIO_LABELS = {
|
||||
"baseline": "+ Cache",
|
||||
"no_cache_baseline": "Naive \n Recompute",
|
||||
"disable_forward_rng": "+ w/o \n Fwd RNG",
|
||||
"disable_forward_and_reverse_rng": "+ w/o \n Bwd RNG",
|
||||
}
|
||||
|
||||
# Paper-style colors and hatches for scenarios
|
||||
SCENARIO_STYLES = {
|
||||
"no_cache_baseline": {"edgecolor": "dimgrey", "hatch": "/////"},
|
||||
"baseline": {"edgecolor": "#63B8B6", "hatch": "xxxxx"},
|
||||
"disable_forward_rng": {"edgecolor": "green", "hatch": "....."},
|
||||
"disable_forward_and_reverse_rng": {"edgecolor": "tomato", "hatch": "\\\\\\\\\\"},
|
||||
}
|
||||
|
||||
|
||||
def load_latest_run(csv_path: Path):
|
||||
rows = []
|
||||
with csv_path.open("r", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
rows.append(row)
|
||||
if not rows:
|
||||
raise SystemExit("CSV is empty: no rows to plot")
|
||||
# Choose latest run_id lexicographically (YYYYMMDD-HHMMSS)
|
||||
run_ids = [r.get("run_id", "") for r in rows]
|
||||
latest = max(run_ids)
|
||||
latest_rows = [r for r in rows if r.get("run_id", "") == latest]
|
||||
if not latest_rows:
|
||||
# Fallback: take last 4 rows
|
||||
latest_rows = rows[-4:]
|
||||
latest = latest_rows[-1].get("run_id", "unknown")
|
||||
return latest, latest_rows
|
||||
|
||||
|
||||
def aggregate_latency(rows):
|
||||
acc = defaultdict(list)
|
||||
for r in rows:
|
||||
sc = r.get("scenario", "")
|
||||
try:
|
||||
val = float(r.get("latency_ms_per_passage", "nan"))
|
||||
except ValueError:
|
||||
continue
|
||||
acc[sc].append(val)
|
||||
avg = {k: (sum(v) / len(v) if v else 0.0) for k, v in acc.items()}
|
||||
return avg
|
||||
|
||||
|
||||
def _auto_cap(values: list[float]) -> float | None:
|
||||
if not values:
|
||||
return None
|
||||
sorted_vals = sorted(values, reverse=True)
|
||||
if len(sorted_vals) < 2:
|
||||
return None
|
||||
max_v, second = sorted_vals[0], sorted_vals[1]
|
||||
if second <= 0:
|
||||
return None
|
||||
# If the tallest bar dwarfs the second by 2.5x+, cap near the second
|
||||
if max_v >= 2.5 * second:
|
||||
return second * 1.1
|
||||
return None
|
||||
|
||||
|
||||
def _add_break_marker(ax, y, rel_x0=0.02, rel_x1=0.98, size=0.02):
|
||||
# Draw small diagonal ticks near left/right to signal cap
|
||||
x0, x1 = rel_x0, rel_x1
|
||||
ax.plot([x0 - size, x0 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
ax.plot([x1 - size, x1 + size], [y + size, y - size], transform=ax.transAxes, color="k", lw=1)
|
||||
|
||||
|
||||
def _fmt_ms(v: float) -> str:
|
||||
if v >= 1000:
|
||||
return f"{v / 1000:.1f}k"
|
||||
return f"{v:.1f}"
|
||||
|
||||
|
||||
def main():
|
||||
# Set LaTeX style for paper figures (matching paper_fig.py)
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.rcParams["font.family"] = "Helvetica"
|
||||
plt.rcParams["ytick.direction"] = "in"
|
||||
plt.rcParams["hatch.linewidth"] = 1.5
|
||||
plt.rcParams["font.weight"] = "bold"
|
||||
plt.rcParams["axes.labelweight"] = "bold"
|
||||
plt.rcParams["text.usetex"] = True
|
||||
|
||||
ap = argparse.ArgumentParser(description=__doc__)
|
||||
ap.add_argument(
|
||||
"--csv",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/bench_results.csv"),
|
||||
help="Path to results CSV (defaults to bench_results.csv)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--out",
|
||||
type=Path,
|
||||
default=Path("add_ablation.pdf"),
|
||||
help="Output image path",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--csv-right",
|
||||
type=Path,
|
||||
default=Path("benchmarks/update/offline_vs_update.csv"),
|
||||
help="Optional: offline_vs_update.csv to render right subplot (A vs B)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Cap Y-axis at this ms value; bars above are hatched and annotated.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--no-auto-cap",
|
||||
action="store_true",
|
||||
help="Disable auto-cap heuristic when --cap-y is not provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--broken-y",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Use a broken Y-axis (two stacked axes with a gap). Overrides --cap-y unless both provided.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--lower-cap-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Lower axes upper bound for broken Y (ms). Default = 1.1x second-highest.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--upper-start-y",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Upper axes lower bound for broken Y (ms). Default = 1.2x second-highest.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
latest_run, latest_rows = load_latest_run(args.csv)
|
||||
avg = aggregate_latency(latest_rows)
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except Exception as e:
|
||||
raise SystemExit(f"matplotlib not available: {e}")
|
||||
|
||||
scenarios = DEFAULT_SCENARIOS
|
||||
values = [avg.get(name, 0.0) for name in scenarios]
|
||||
labels = [SCENARIO_LABELS.get(name, name) for name in scenarios]
|
||||
colors = ["#4e79a7", "#f28e2c", "#e15759", "#76b7b2"]
|
||||
|
||||
# If right CSV is provided, build side-by-side figure
|
||||
if args.csv_right is not None:
|
||||
try:
|
||||
right_rows_all = []
|
||||
with args.csv_right.open("r", encoding="utf-8") as f:
|
||||
rreader = csv.DictReader(f)
|
||||
right_rows_all = list(rreader)
|
||||
if right_rows_all:
|
||||
r_latest = max(r.get("run_id", "") for r in right_rows_all)
|
||||
right_rows = [r for r in right_rows_all if r.get("run_id", "") == r_latest]
|
||||
else:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
except Exception:
|
||||
r_latest = None
|
||||
right_rows = []
|
||||
|
||||
a_total = 0.0
|
||||
b_makespan = 0.0
|
||||
for r in right_rows:
|
||||
sc = (r.get("scenario", "") or "").strip().upper()
|
||||
if sc == "A":
|
||||
try:
|
||||
a_total = float(r.get("total_time_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
elif sc == "B":
|
||||
try:
|
||||
b_makespan = float(r.get("makespan_s", 0.0))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import gridspec
|
||||
|
||||
# Left subplot (reuse current style, with optional cap)
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
x = list(range(len(labels)))
|
||||
|
||||
if args.broken_y:
|
||||
# Use broken axis for left subplot
|
||||
# Auto-adjust width ratios: left has 4 bars, right has 2 bars
|
||||
fig = plt.figure(figsize=(4.8, 1.8)) # Scaled down to 80%
|
||||
gs = gridspec.GridSpec(
|
||||
2, 2, height_ratios=[1, 3], width_ratios=[1.5, 1], hspace=0.08, wspace=0.35
|
||||
)
|
||||
ax_left_top = fig.add_subplot(gs[0, 0])
|
||||
ax_left_bottom = fig.add_subplot(gs[1, 0], sharex=ax_left_top)
|
||||
ax_right = fig.add_subplot(gs[:, 1])
|
||||
|
||||
# Determine break points
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = (
|
||||
args.lower_cap_y if args.lower_cap_y is not None else second * 1.4
|
||||
) # Increased to show more range
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.5, lower_cap * 1.02)
|
||||
)
|
||||
ymax = (
|
||||
max(values) * 1.90 if values else 1.0
|
||||
) # Increase headroom to 1.90 for text label and tick range
|
||||
|
||||
# Draw bars on both axes
|
||||
ax_left_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_left_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Set limits
|
||||
ax_left_bottom.set_ylim(0, lower_cap)
|
||||
ax_left_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values (convert ms to s)
|
||||
values_s = [v / 1000.0 for v in values]
|
||||
lower_cap_s = lower_cap / 1000.0
|
||||
upper_start_s = upper_start / 1000.0
|
||||
ymax_s = ymax / 1000.0
|
||||
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
# Redraw bars with s values (paper style: white fill + colored edge + hatch)
|
||||
ax_left_bottom.clear()
|
||||
ax_left_top.clear()
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (scenario_name, v) in enumerate(zip(scenarios, values_s)):
|
||||
style = SCENARIO_STYLES.get(scenario_name, {"edgecolor": "black", "hatch": ""})
|
||||
# Draw in bottom axis for all bars
|
||||
ax_left_bottom.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
# Only draw in top axis if the bar is tall enough to reach the upper range
|
||||
if v > upper_start_s:
|
||||
ax_left_top.bar(
|
||||
i,
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
ax_left_bottom.set_ylim(0, lower_cap_s)
|
||||
ax_left_top.set_ylim(upper_start_s, ymax_s)
|
||||
|
||||
for i, v in enumerate(values_s):
|
||||
if v <= lower_cap_s:
|
||||
ax_left_bottom.text(
|
||||
i,
|
||||
v + lower_cap_s * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
else:
|
||||
ax_left_top.text(
|
||||
i,
|
||||
v + (ymax_s - upper_start_s) * 0.02,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# Hide spines between axes
|
||||
ax_left_top.spines["bottom"].set_visible(False)
|
||||
ax_left_bottom.spines["top"].set_visible(False)
|
||||
ax_left_top.tick_params(
|
||||
labeltop=False, labelbottom=False, bottom=False
|
||||
) # Hide tick marks
|
||||
ax_left_bottom.xaxis.tick_bottom()
|
||||
ax_left_bottom.tick_params(top=False) # Hide top tick marks
|
||||
|
||||
# Draw break marks (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_left_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_left_top.plot((-d, +d), (-d, +d), **kwargs)
|
||||
ax_left_top.plot((1 - d, 1 + d), (-d, +d), **kwargs)
|
||||
kwargs.update({"transform": ax_left_bottom.transAxes})
|
||||
ax_left_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs)
|
||||
ax_left_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)
|
||||
|
||||
ax_left_bottom.set_xticks(x)
|
||||
ax_left_bottom.set_xticklabels(labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_left_bottom.tick_params(axis="y", labelsize=10)
|
||||
ax_left_top.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_left_bottom.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_left_top.set_title("Single Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match bar width with right subplot
|
||||
ax_left_bottom.set_xlim(-0.6, 3.6)
|
||||
ax_left_top.set_xlim(-0.6, 3.6)
|
||||
|
||||
ax_left = ax_left_bottom # for compatibility
|
||||
else:
|
||||
# Regular side-by-side layout
|
||||
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8.4, 3.15))
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = ax_left.bar(x, show_vals, color=colors[: len(labels)], width=0.8)
|
||||
for i, (val, show) in enumerate(zip(values, show_vals)):
|
||||
if val > cap:
|
||||
bars[i].set_hatch("//")
|
||||
ax_left.text(
|
||||
i, cap * 1.02, _fmt_ms(val), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_left.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
_fmt_ms(val),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax_left.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax_left, y=0.98)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
else:
|
||||
ax_left.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
for i, v in enumerate(values):
|
||||
ax_left.text(i, v + 1.0, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
ax_left.set_xticks(x)
|
||||
ax_left.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax_left.set_ylabel("Latency (ms per passage)")
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
ax_left.set_title(
|
||||
f"HNSW RNG (run {latest_run}) | init={max_initial}, upd={max_updates}"
|
||||
)
|
||||
|
||||
# Right subplot (A vs B, seconds) - paper style
|
||||
r_labels = ["Sequential", "Delayed \n Add+Search"]
|
||||
r_values = [a_total or 0.0, b_makespan or 0.0]
|
||||
r_styles = [
|
||||
{"edgecolor": "#59a14f", "hatch": "xxxxx"},
|
||||
{"edgecolor": "#edc948", "hatch": "/////"},
|
||||
]
|
||||
# 2 bars, centered with proper spacing
|
||||
xr = [0, 1]
|
||||
bar_width = 0.50 # Reduced for wider spacing between bars
|
||||
for i, (v, style) in enumerate(zip(r_values, r_styles)):
|
||||
ax_right.bar(
|
||||
xr[i],
|
||||
v,
|
||||
width=bar_width,
|
||||
color="white",
|
||||
edgecolor=style["edgecolor"],
|
||||
hatch=style["hatch"],
|
||||
linewidth=1.2,
|
||||
)
|
||||
for i, v in enumerate(r_values):
|
||||
max_v = max(r_values) if r_values else 1.0
|
||||
offset = max(0.0002, 0.02 * max_v)
|
||||
ax_right.text(
|
||||
xr[i],
|
||||
v + offset,
|
||||
f"{v:.2f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=8,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_right.set_xticks(xr)
|
||||
ax_right.set_xticklabels(r_labels, rotation=0, fontsize=7)
|
||||
# Don't set ylabel here - will use fig.text for alignment
|
||||
ax_right.tick_params(axis="y", labelsize=10)
|
||||
# Add subtle grid for better readability
|
||||
ax_right.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5)
|
||||
ax_right.set_title("Batched Add Operation", fontsize=11, pad=10, fontweight="bold")
|
||||
|
||||
# Set x-axis limits to match left subplot's bar width visually
|
||||
# Accounting for width_ratios=[1.5, 1]:
|
||||
# Left: 4 bars, xlim(-0.6, 3.6), range=4.2, physical_width=1.5*unit
|
||||
# bar_width_visual = 0.72 * (1.5*unit / 4.2)
|
||||
# Right: 2 bars, need same visual width
|
||||
# 0.72 * (1.0*unit / range_right) = 0.72 * (1.5*unit / 4.2)
|
||||
# range_right = 4.2 / 1.5 = 2.8
|
||||
# For bars at 0, 1: padding = (2.8 - 1) / 2 = 0.9
|
||||
ax_right.set_xlim(-0.9, 1.9)
|
||||
|
||||
# Set y-axis limit with headroom for text labels
|
||||
if r_values:
|
||||
max_v = max(r_values)
|
||||
ax_right.set_ylim(0, max_v * 1.15)
|
||||
|
||||
# Format y-axis to avoid scientific notation
|
||||
ax_right.ticklabel_format(style="plain", axis="y")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# Add aligned ylabels using fig.text (after tight_layout)
|
||||
# Get the vertical center of the entire figure
|
||||
fig_center_y = 0.5
|
||||
# Left ylabel - closer to left plot
|
||||
left_x = 0.05
|
||||
fig.text(
|
||||
left_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
# Right ylabel - closer to right plot
|
||||
right_bbox = ax_right.get_position()
|
||||
right_x = right_bbox.x0 - 0.07
|
||||
fig.text(
|
||||
right_x,
|
||||
fig_center_y,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
return
|
||||
|
||||
# Broken-Y mode
|
||||
if args.broken_y:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
fig, (ax_top, ax_bottom) = plt.subplots(
|
||||
2,
|
||||
1,
|
||||
sharex=True,
|
||||
figsize=(7.5, 6.75),
|
||||
gridspec_kw={"height_ratios": [1, 3], "hspace": 0.08},
|
||||
)
|
||||
|
||||
# Determine default breaks from second-highest
|
||||
s = sorted(values, reverse=True)
|
||||
second = s[1] if len(s) >= 2 else (s[0] if s else 0.0)
|
||||
lower_cap = args.lower_cap_y if args.lower_cap_y is not None else second * 1.1
|
||||
upper_start = (
|
||||
args.upper_start_y
|
||||
if args.upper_start_y is not None
|
||||
else max(second * 1.2, lower_cap * 1.02)
|
||||
)
|
||||
ymax = max(values) * 1.10 if values else 1.0
|
||||
|
||||
x = list(range(len(labels)))
|
||||
ax_bottom.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
ax_top.bar(x, values, color=colors[: len(labels)], width=0.8)
|
||||
|
||||
# Limits
|
||||
ax_bottom.set_ylim(0, lower_cap)
|
||||
ax_top.set_ylim(upper_start, ymax)
|
||||
|
||||
# Annotate values
|
||||
for i, v in enumerate(values):
|
||||
if v <= lower_cap:
|
||||
ax_bottom.text(
|
||||
i, v + lower_cap * 0.02, _fmt_ms(v), ha="center", va="bottom", fontsize=9
|
||||
)
|
||||
else:
|
||||
ax_top.text(i, v, _fmt_ms(v), ha="center", va="bottom", fontsize=9)
|
||||
|
||||
# Hide spines between axes and draw diagonal break marks
|
||||
ax_top.spines["bottom"].set_visible(False)
|
||||
ax_bottom.spines["top"].set_visible(False)
|
||||
ax_top.tick_params(labeltop=False) # don't put tick labels at the top
|
||||
ax_bottom.xaxis.tick_bottom()
|
||||
|
||||
# Diagonal lines at the break (matching paper_fig.py style)
|
||||
d = 0.015
|
||||
kwargs = {
|
||||
"transform": ax_top.transAxes,
|
||||
"color": "k",
|
||||
"clip_on": False,
|
||||
"linewidth": 0.8,
|
||||
"zorder": 10,
|
||||
}
|
||||
ax_top.plot((-d, +d), (-d, +d), **kwargs) # top-left diagonal
|
||||
ax_top.plot((1 - d, 1 + d), (-d, +d), **kwargs) # top-right diagonal
|
||||
kwargs.update({"transform": ax_bottom.transAxes})
|
||||
ax_bottom.plot((-d, +d), (1 - d, 1 + d), **kwargs) # bottom-left diagonal
|
||||
ax_bottom.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs) # bottom-right diagonal
|
||||
|
||||
ax_bottom.set_xticks(x)
|
||||
ax_bottom.set_xticklabels(labels, rotation=0, fontsize=10)
|
||||
ax = ax_bottom # for labeling below
|
||||
else:
|
||||
cap = args.cap_y
|
||||
if cap is None and not args.no_auto_cap:
|
||||
cap = _auto_cap(values)
|
||||
|
||||
plt.figure(figsize=(5.4, 3.15))
|
||||
ax = plt.gca()
|
||||
|
||||
if cap is not None:
|
||||
show_vals = [min(v, cap) for v in values]
|
||||
bars = []
|
||||
for i, (_label, val, show) in enumerate(zip(labels, values, show_vals)):
|
||||
bar = ax.bar(i, show, color=colors[i], width=0.8)
|
||||
bars.append(bar[0])
|
||||
# Hatch and annotate when capped
|
||||
if val > cap:
|
||||
bars[-1].set_hatch("//")
|
||||
ax.text(i, cap * 1.02, f"{_fmt_ms(val)}", ha="center", va="bottom", fontsize=9)
|
||||
else:
|
||||
ax.text(
|
||||
i,
|
||||
show + max(1.0, 0.01 * (cap or show)),
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=9,
|
||||
)
|
||||
ax.set_ylim(0, cap * 1.10)
|
||||
_add_break_marker(ax, y=0.98)
|
||||
ax.legend([bars[1]], ["capped"], fontsize=8, frameon=False, loc="upper right") if any(
|
||||
v > cap for v in values
|
||||
) else None
|
||||
ax.set_xticks(range(len(labels)))
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
else:
|
||||
ax.bar(labels, values, color=colors[: len(labels)])
|
||||
for idx, val in enumerate(values):
|
||||
ax.text(
|
||||
idx,
|
||||
val + 1.0,
|
||||
f"{_fmt_ms(val)}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax.set_xticklabels(labels, fontsize=11, fontweight="bold")
|
||||
# Try to extract some context for title
|
||||
max_initial = latest_rows[0].get("max_initial", "?")
|
||||
max_updates = latest_rows[0].get("max_updates", "?")
|
||||
|
||||
if args.broken_y:
|
||||
fig.text(
|
||||
0.02,
|
||||
0.5,
|
||||
"Latency (s)",
|
||||
va="center",
|
||||
rotation="vertical",
|
||||
fontsize=11,
|
||||
fontweight="bold",
|
||||
)
|
||||
fig.suptitle(
|
||||
"Add Operation Latency",
|
||||
fontsize=11,
|
||||
y=0.98,
|
||||
fontweight="bold",
|
||||
)
|
||||
plt.tight_layout(rect=(0.03, 0.04, 1, 0.96))
|
||||
else:
|
||||
plt.ylabel("Latency (s)", fontsize=11, fontweight="bold")
|
||||
plt.title("Add Operation Latency", fontsize=11, fontweight="bold")
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig(args.out, bbox_inches="tight", pad_inches=0.05)
|
||||
# Also save PDF for paper
|
||||
pdf_out = args.out.with_suffix(".pdf")
|
||||
plt.savefig(pdf_out, bbox_inches="tight", pad_inches=0.05)
|
||||
print(f"Saved: {args.out}")
|
||||
print(f"Saved: {pdf_out}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
data/.gitattributes
vendored
82
data/.gitattributes
vendored
@@ -1,82 +0,0 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lz4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.mds filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - uncompressed
|
||||
*.pcm filter=lfs diff=lfs merge=lfs -text
|
||||
*.sam filter=lfs diff=lfs merge=lfs -text
|
||||
*.raw filter=lfs diff=lfs merge=lfs -text
|
||||
# Audio files - compressed
|
||||
*.aac filter=lfs diff=lfs merge=lfs -text
|
||||
*.flac filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ogg filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - uncompressed
|
||||
*.bmp filter=lfs diff=lfs merge=lfs -text
|
||||
*.gif filter=lfs diff=lfs merge=lfs -text
|
||||
*.png filter=lfs diff=lfs merge=lfs -text
|
||||
*.tiff filter=lfs diff=lfs merge=lfs -text
|
||||
# Image files - compressed
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
||||
*.webp filter=lfs diff=lfs merge=lfs -text
|
||||
# Video files - compressed
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||
ground_truth/dpr/id_map.json filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann.passages.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/dpr_diskann_disk.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/dpr/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/leann.labels.map filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.index filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.0.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.1.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.2.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.3.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.4.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.5.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.6.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.idx filter=lfs diff=lfs merge=lfs -text
|
||||
indices/rpj_wiki/rpj_wiki.passages.7.jsonl filter=lfs diff=lfs merge=lfs -text
|
||||
@@ -1,5 +1,5 @@
|
||||
The Project Gutenberg eBook of Pride and Prejudice
|
||||
|
||||
|
||||
This ebook is for the use of anyone anywhere in the United States and
|
||||
most other parts of the world at no cost and with almost no restrictions
|
||||
whatsoever. You may copy it, give it away or re-use it under the terms
|
||||
@@ -14557,7 +14557,7 @@ her into Derbyshire, had been the means of uniting them.
|
||||
*** END OF THE PROJECT GUTENBERG EBOOK PRIDE AND PREJUDICE ***
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Updated editions will replace the previous one—the old editions will
|
||||
be renamed.
|
||||
@@ -14662,7 +14662,7 @@ performed, viewed, copied or distributed:
|
||||
at www.gutenberg.org. If you
|
||||
are not located in the United States, you will have to check the laws
|
||||
of the country where you are located before using this eBook.
|
||||
|
||||
|
||||
1.E.2. If an individual Project Gutenberg™ electronic work is
|
||||
derived from texts not protected by U.S. copyright law (does not
|
||||
contain a notice indicating that it is posted with permission of the
|
||||
@@ -14724,7 +14724,7 @@ provided that:
|
||||
Gutenberg Literary Archive Foundation at the address specified in
|
||||
Section 4, “Information about donations to the Project Gutenberg
|
||||
Literary Archive Foundation.”
|
||||
|
||||
|
||||
• You provide a full refund of any money paid by a user who notifies
|
||||
you in writing (or by e-mail) within 30 days of receipt that s/he
|
||||
does not agree to the terms of the full Project Gutenberg™
|
||||
@@ -14732,15 +14732,15 @@ provided that:
|
||||
copies of the works possessed in a physical medium and discontinue
|
||||
all use of and all access to other copies of Project Gutenberg™
|
||||
works.
|
||||
|
||||
|
||||
• You provide, in accordance with paragraph 1.F.3, a full refund of
|
||||
any money paid for a work or a replacement copy, if a defect in the
|
||||
electronic work is discovered and reported to you within 90 days of
|
||||
receipt of the work.
|
||||
|
||||
|
||||
• You comply with all other terms of this agreement for free
|
||||
distribution of Project Gutenberg™ works.
|
||||
|
||||
|
||||
|
||||
1.E.9. If you wish to charge a fee or distribute a Project
|
||||
Gutenberg™ electronic work or group of works on different terms than
|
||||
@@ -14903,5 +14903,3 @@ This website includes information about Project Gutenberg™,
|
||||
including how to make donations to the Project Gutenberg Literary
|
||||
Archive Foundation, how to help produce our new eBooks, and how to
|
||||
subscribe to our email newsletter to hear about new eBooks.
|
||||
|
||||
|
||||
277
demo.ipynb
277
demo.ipynb
@@ -4,7 +4,11 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Quick Start in 30s"
|
||||
"# Quick Start \n",
|
||||
"\n",
|
||||
"**Home GitHub Repository:** [LEANN on GitHub](https://github.com/yichuan-w/LEANN)\n",
|
||||
"\n",
|
||||
"**Important for Colab users:** Set your runtime type to T4 GPU for optimal performance. Go to Runtime → Change runtime type → Hardware accelerator → T4 GPU."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -13,8 +17,25 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# install this if you areusing colab\n",
|
||||
"! pip install leann"
|
||||
"# install this if you are using colab\n",
|
||||
"! uv pip install leann-core leann-backend-hnsw --no-deps\n",
|
||||
"! uv pip install leann --no-deps\n",
|
||||
"# For Colab environment, we need to set some environment variables\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"LEANN_LOG_LEVEL\"] = \"INFO\" # Enable more detailed logging"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"INDEX_DIR = Path(\"./\").resolve()\n",
|
||||
"INDEX_PATH = str(INDEX_DIR / \"demo.leann\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -26,91 +47,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/yichuan/Desktop/code/LEANN/leann/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"Writing passages: 100%|██████████| 5/5 [00:00<00:00, 27887.66chunk/s]\n",
|
||||
"Batches: 100%|██████████| 1/1 [00:00<00:00, 13.51it/s]\n",
|
||||
"WARNING:leann_backend_hnsw.hnsw_backend:Converting data to float32, shape: (5, 768)\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Converting HNSW index to CSR-pruned format...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"M: 64 for level: 0\n",
|
||||
"Starting conversion: knowledge.index -> knowledge.csr.tmp\n",
|
||||
"[0.00s] Reading Index HNSW header...\n",
|
||||
"[0.00s] Header read: d=768, ntotal=5\n",
|
||||
"[0.00s] Reading HNSW struct vectors...\n",
|
||||
" Reading vector (dtype=<class 'numpy.float64'>, fmt='d')... Count=6, Bytes=48\n",
|
||||
"[0.00s] Read assign_probas (6)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=7, Bytes=28\n",
|
||||
"[0.11s] Read cum_nneighbor_per_level (7)\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=5, Bytes=20\n",
|
||||
"[0.21s] Read levels (5)\n",
|
||||
"[0.30s] Probing for compact storage flag...\n",
|
||||
"[0.30s] Found compact flag: False\n",
|
||||
"[0.30s] Compact flag is False, reading original format...\n",
|
||||
"[0.30s] Probing for potential extra byte before non-compact offsets...\n",
|
||||
"[0.30s] Found and consumed an unexpected 0x00 byte.\n",
|
||||
" Reading vector (dtype=<class 'numpy.uint64'>, fmt='Q')... Count=6, Bytes=48\n",
|
||||
"[0.30s] Read offsets (6)\n",
|
||||
"[0.40s] Attempting to read neighbors vector...\n",
|
||||
" Reading vector (dtype=<class 'numpy.int32'>, fmt='i')... Count=320, Bytes=1280\n",
|
||||
"[0.40s] Read neighbors (320)\n",
|
||||
"[0.50s] Read scalar params (ep=4, max_lvl=0)\n",
|
||||
"[0.50s] Checking for storage data...\n",
|
||||
"[0.50s] Found storage fourcc: 49467849.\n",
|
||||
"[0.50s] Converting to CSR format...\n",
|
||||
"[0.50s] Conversion loop finished. \n",
|
||||
"[0.50s] Running validation checks...\n",
|
||||
" Checking total valid neighbor count...\n",
|
||||
" OK: Total valid neighbors = 20\n",
|
||||
" Checking final pointer indices...\n",
|
||||
" OK: Final pointers match data size.\n",
|
||||
"[0.50s] Deleting original neighbors and offsets arrays...\n",
|
||||
" CSR Stats: |data|=20, |level_ptr|=10\n",
|
||||
"[0.59s] Writing CSR HNSW graph data in FAISS-compatible order...\n",
|
||||
" Pruning embeddings: Writing NULL storage marker.\n",
|
||||
"[0.69s] Conversion complete.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:✅ CSR conversion successful.\n",
|
||||
"INFO:leann_backend_hnsw.hnsw_backend:INFO: Replaced original index with CSR-pruned version at 'knowledge.index'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannBuilder\n",
|
||||
"\n",
|
||||
"builder = LeannBuilder(backend_name=\"hnsw\")\n",
|
||||
"builder.add_text(\"C# is a powerful programming language and it is good at game development\")\n",
|
||||
"builder.add_text(\"Python is a powerful programming language and it is good at machine learning tasks\")\n",
|
||||
"builder.add_text(\n",
|
||||
" \"Python is a powerful programming language and it is good at machine learning tasks\"\n",
|
||||
")\n",
|
||||
"builder.add_text(\"Machine learning transforms industries\")\n",
|
||||
"builder.add_text(\"Neural networks process complex data\")\n",
|
||||
"builder.add_text(\"Leann is a great storage saving engine for RAG on your MacBook\")\n",
|
||||
"builder.build_index(\"knowledge.leann\")"
|
||||
"builder.build_index(INDEX_PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -122,97 +73,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'programming languages'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Using port 5560 instead of 5557\n",
|
||||
"INFO:leann.embedding_server_manager:Starting embedding server on port 5560...\n",
|
||||
"INFO:leann.embedding_server_manager:Command: /Users/yichuan/Desktop/code/LEANN/leann/.venv/bin/python -m leann_backend_hnsw.hnsw_embedding_server --zmq-port 5560 --model-name facebook/contriever --passages-file knowledge.leann.meta.json\n",
|
||||
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
|
||||
"To disable this warning, you can either:\n",
|
||||
"\t- Avoid using `tokenizers` before the fork if possible\n",
|
||||
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
|
||||
"INFO:leann.embedding_server_manager:Server process started with PID: 4574\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n",
|
||||
"INFO: Registering backend 'hnsw'\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.embedding_server_manager:Embedding server is ready!\n",
|
||||
"INFO:leann.api: Launching server time: 1.078078269958496 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Existing server process (PID 4574) is compatible\n",
|
||||
"INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: facebook/contriever\n",
|
||||
"WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name facebook/contriever. Creating a new one with mean pooling.\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 2.9307072162628174 seconds\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api: Search time: 0.27327895164489746 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[SearchResult(id='0', score=np.float32(0.9874103), text='C# is a powerful programming language and it is good at game development', metadata={}),\n",
|
||||
" SearchResult(id='1', score=np.float32(0.8922168), text='Python is a powerful programming language and it is good at machine learning tasks', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannSearcher\n",
|
||||
"\n",
|
||||
"searcher = LeannSearcher(\"knowledge.leann\")\n",
|
||||
"searcher = LeannSearcher(INDEX_PATH)\n",
|
||||
"results = searcher.search(\"programming languages\", top_k=2)\n",
|
||||
"results"
|
||||
]
|
||||
@@ -228,79 +95,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.chat:Attempting to create LLM of type='hf' with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:Initializing HFChat with model='Qwen/Qwen3-0.6B'\n",
|
||||
"INFO:leann.chat:MPS is available. Using Apple Silicon GPU.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[read_HNSW - CSR NL v4] Reading metadata & CSR indices (manual offset)...\n",
|
||||
"[read_HNSW NL v4] Read levels vector, size: 5\n",
|
||||
"[read_HNSW NL v4] Reading Compact Storage format indices...\n",
|
||||
"[read_HNSW NL v4] Read compact_level_ptr, size: 10\n",
|
||||
"[read_HNSW NL v4] Read compact_node_offsets, size: 6\n",
|
||||
"[read_HNSW NL v4] Read entry_point: 4, max_level: 0\n",
|
||||
"[read_HNSW NL v4] Read storage fourcc: 0x6c6c756e\n",
|
||||
"[read_HNSW NL v4 FIX] Detected FileIOReader. Neighbors size field offset: 326\n",
|
||||
"[read_HNSW NL v4] Reading neighbors data into memory.\n",
|
||||
"[read_HNSW NL v4] Read neighbors data, size: 20\n",
|
||||
"[read_HNSW NL v4] Finished reading metadata and CSR indices.\n",
|
||||
"INFO: Skipping external storage loading, since is_recompute is true.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"INFO:leann.api:🔍 LeannSearcher.search() called:\n",
|
||||
"INFO:leann.api: Query: 'Compare the two retrieved programming languages and tell me their advantages.'\n",
|
||||
"INFO:leann.api: Top_k: 2\n",
|
||||
"INFO:leann.api: Additional kwargs: {}\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5557 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5558 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Port 5559 has incompatible server, trying next port...\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Launching server time: 0.04932403564453125 seconds\n",
|
||||
"INFO:leann.embedding_server_manager:Found compatible server on port 5560\n",
|
||||
"INFO:leann.embedding_server_manager:Using existing compatible server on port 5560\n",
|
||||
"INFO:leann.api: Generated embedding shape: (1, 768)\n",
|
||||
"INFO:leann.api: Embedding time: 0.06902289390563965 seconds\n",
|
||||
"INFO:leann.api: Search time: 0.026793241500854492 seconds\n",
|
||||
"INFO:leann.api: Backend returned: labels=2 results\n",
|
||||
"INFO:leann.api: Processing 2 passage IDs:\n",
|
||||
"INFO:leann.api: 1. passage_id='0' -> SUCCESS: C# is a powerful programming language and it is good at game development...\n",
|
||||
"INFO:leann.api: 2. passage_id='1' -> SUCCESS: Python is a powerful programming language and it is good at machine learning tasks...\n",
|
||||
"INFO:leann.api: Final enriched results: 2 passages\n",
|
||||
"INFO:leann.chat:Generating with HuggingFace model, config: {'max_new_tokens': 128, 'temperature': 0.7, 'top_p': 0.9, 'do_sample': True, 'pad_token_id': 151645, 'eos_token_id': 151645}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ZmqDistanceComputer initialized: d=768, metric=0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"<think>\\n\\n</think>\\n\\nBased on the context provided, here's a comparison of the two retrieved programming languages:\\n\\n**C#** is known for being a powerful programming language and is well-suited for game development. It is often used in game development and is popular among developers working on Windows applications.\\n\\n**Python**, on the other hand, is also a powerful language and is well-suited for machine learning tasks. It is widely used for data analysis, scientific computing, and other applications that require handling large datasets or performing complex calculations.\\n\\n**Advantages**:\\n- C#: Strong for game development and cross-platform compatibility.\\n- Python: Strong for\""
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from leann.api import LeannChat\n",
|
||||
"\n",
|
||||
@@ -309,11 +104,11 @@
|
||||
" \"model\": \"Qwen/Qwen3-0.6B\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"chat = LeannChat(index_path=\"knowledge.leann\", llm_config=llm_config)\n",
|
||||
"chat = LeannChat(index_path=INDEX_PATH, llm_config=llm_config)\n",
|
||||
"response = chat.ask(\n",
|
||||
" \"Compare the two retrieved programming languages and tell me their advantages.\",\n",
|
||||
" top_k=2,\n",
|
||||
" llm_kwargs={\"max_tokens\": 128}\n",
|
||||
" llm_kwargs={\"max_tokens\": 128},\n",
|
||||
")\n",
|
||||
"response"
|
||||
]
|
||||
|
||||
200
docs/COLQWEN_GUIDE.md
Normal file
200
docs/COLQWEN_GUIDE.md
Normal file
@@ -0,0 +1,200 @@
|
||||
# ColQwen Integration Guide
|
||||
|
||||
Easy-to-use multimodal PDF retrieval with ColQwen2/ColPali models.
|
||||
|
||||
## Quick Start
|
||||
|
||||
> **🍎 Mac Users**: ColQwen is optimized for Apple Silicon with MPS acceleration for faster inference!
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
brew install poppler # macOS only, for PDF processing
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```bash
|
||||
# Build index from PDFs
|
||||
python -m apps.colqwen_rag build --pdfs ./my_papers/ --index research_papers
|
||||
|
||||
# Search with text queries
|
||||
python -m apps.colqwen_rag search research_papers "How does attention mechanism work?"
|
||||
|
||||
# Interactive Q&A
|
||||
python -m apps.colqwen_rag ask research_papers --interactive
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
### Build Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag build \
|
||||
--pdfs ./pdf_directory/ \
|
||||
--index my_index \
|
||||
--model colqwen2 \
|
||||
--pages-dir ./page_images/ # Optional: save page images
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--pdfs`: Directory containing PDF files (or single PDF path)
|
||||
- `--index`: Name for the index (required)
|
||||
- `--model`: `colqwen2` (default) or `colpali`
|
||||
- `--pages-dir`: Directory to save page images (optional)
|
||||
|
||||
### Search Index
|
||||
```bash
|
||||
python -m apps.colqwen_rag search my_index "your question here" --top-k 5
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- `--top-k`: Number of results to return (default: 5)
|
||||
- `--model`: Model used for search (should match build model)
|
||||
|
||||
### Interactive Q&A
|
||||
```bash
|
||||
python -m apps.colqwen_rag ask my_index --interactive
|
||||
```
|
||||
|
||||
**Commands in interactive mode:**
|
||||
- Type your questions naturally
|
||||
- `help`: Show available commands
|
||||
- `quit`/`exit`/`q`: Exit interactive mode
|
||||
|
||||
## 🧪 Test & Reproduce Results
|
||||
|
||||
Run the reproduction test for issue #119:
|
||||
```bash
|
||||
python test_colqwen_reproduction.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. ✅ Check dependencies
|
||||
2. 📥 Download sample PDF (Attention Is All You Need paper)
|
||||
3. 🏗️ Build test index
|
||||
4. 🔍 Run sample queries
|
||||
5. 📊 Show how to generate similarity maps
|
||||
|
||||
## 🎨 Advanced: Similarity Maps
|
||||
|
||||
For visual similarity analysis, use the existing advanced script:
|
||||
```bash
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
python multi-vector-leann-similarity-map.py
|
||||
```
|
||||
|
||||
Edit the script to customize:
|
||||
- `QUERY`: Your question
|
||||
- `MODEL`: "colqwen2" or "colpali"
|
||||
- `USE_HF_DATASET`: Use HuggingFace dataset or local PDFs
|
||||
- `SIMILARITY_MAP`: Generate heatmaps
|
||||
- `ANSWER`: Enable Qwen-VL answer generation
|
||||
|
||||
## 🔧 How It Works
|
||||
|
||||
### ColQwen2 vs ColPali
|
||||
- **ColQwen2** (`vidore/colqwen2-v1.0`): Latest vision-language model
|
||||
- **ColPali** (`vidore/colpali-v1.2`): Proven multimodal retriever
|
||||
|
||||
### Architecture
|
||||
1. **PDF → Images**: Convert PDF pages to images (150 DPI)
|
||||
2. **Vision Encoding**: Process images with ColQwen2/ColPali
|
||||
3. **Multi-Vector Index**: Build LEANN HNSW index with multiple embeddings per page
|
||||
4. **Query Processing**: Encode text queries with same model
|
||||
5. **Similarity Search**: Find most relevant pages/regions
|
||||
6. **Visual Maps**: Generate attention heatmaps (optional)
|
||||
|
||||
### Device Support
|
||||
- **CUDA**: Best performance with GPU acceleration
|
||||
- **MPS**: Apple Silicon Mac support
|
||||
- **CPU**: Fallback for any system (slower)
|
||||
|
||||
Auto-detection: CUDA > MPS > CPU
|
||||
|
||||
## 📊 Performance Tips
|
||||
|
||||
### For Best Performance:
|
||||
```bash
|
||||
# Use ColQwen2 for latest features
|
||||
--model colqwen2
|
||||
|
||||
# Save page images for reuse
|
||||
--pages-dir ./cached_pages/
|
||||
|
||||
# Adjust batch size based on GPU memory
|
||||
# (automatically handled)
|
||||
```
|
||||
|
||||
### For Large Document Sets:
|
||||
- Process PDFs in batches
|
||||
- Use SSD storage for index files
|
||||
- Consider using CUDA if available
|
||||
|
||||
## 🔗 Related Resources
|
||||
|
||||
- **Fast-PLAID**: https://github.com/lightonai/fast-plaid
|
||||
- **Pylate**: https://github.com/lightonai/pylate
|
||||
- **ColBERT**: https://github.com/stanford-futuredata/ColBERT
|
||||
- **ColPali Paper**: Vision-Language Models for Document Retrieval
|
||||
- **Issue #119**: https://github.com/yichuan-w/LEANN/issues/119
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### PDF Conversion Issues (macOS)
|
||||
```bash
|
||||
# Install poppler
|
||||
brew install poppler
|
||||
which pdfinfo && pdfinfo -v
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
- Reduce batch size (automatically handled)
|
||||
- Use CPU instead of GPU: `export CUDA_VISIBLE_DEVICES=""`
|
||||
- Process fewer PDFs at once
|
||||
|
||||
### Model Download Issues
|
||||
- Ensure internet connection for first run
|
||||
- Models are cached after first download
|
||||
- Use HuggingFace mirrors if needed
|
||||
|
||||
### Import Errors
|
||||
```bash
|
||||
# Ensure all dependencies installed
|
||||
uv pip install colpali_engine pdf2image pillow matplotlib qwen_vl_utils einops seaborn
|
||||
|
||||
# Check PyTorch installation
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
```
|
||||
|
||||
## 💡 Examples
|
||||
|
||||
### Research Paper Analysis
|
||||
```bash
|
||||
# Index your research papers
|
||||
python -m apps.colqwen_rag build --pdfs ~/Papers/AI/ --index ai_papers
|
||||
|
||||
# Ask research questions
|
||||
python -m apps.colqwen_rag search ai_papers "What are the limitations of transformer models?"
|
||||
python -m apps.colqwen_rag search ai_papers "How does BERT compare to GPT?"
|
||||
```
|
||||
|
||||
### Document Q&A
|
||||
```bash
|
||||
# Index business documents
|
||||
python -m apps.colqwen_rag build --pdfs ~/Documents/Reports/ --index reports
|
||||
|
||||
# Interactive analysis
|
||||
python -m apps.colqwen_rag ask reports --interactive
|
||||
```
|
||||
|
||||
### Visual Analysis
|
||||
```bash
|
||||
# Generate similarity maps for specific queries
|
||||
cd apps/multimodal/vision-based-pdf-multi-vector/
|
||||
# Edit multi-vector-leann-similarity-map.py with your query
|
||||
python multi-vector-leann-similarity-map.py
|
||||
# Check ./figures/ for generated heatmaps
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**🎯 This integration makes ColQwen as easy to use as other LEANN features while maintaining the full power of multimodal document understanding!**
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user